Commit bead0a38 authored by Phil Wang's avatar Phil Wang
Browse files

add ability to train the soundstream to be denoising, as in the paper, may be...

Add the ability to train SoundStream as a denoiser, as in the paper; this may be needed for NaturalSpeech 2.
parent d05c020f
Loading
Loading
Loading
Loading
+45 −8
Original line number Diff line number Diff line
@@ -404,6 +404,15 @@ class LocalTransformer(nn.Module):

        return x

class FiLM(nn.Module):
    """Feature-wise linear modulation of `x` by a conditioning tensor.

    A single linear layer maps `cond` to `dim * 2` features; the first half
    is used as a multiplicative scale and the second half as an additive
    shift on `x`.

    Args:
        dim: number of features in each of the scale and shift halves.
        dim_cond: number of input features of the conditioning projection.
    """

    def __init__(self, dim, dim_cond):
        super().__init__()
        # one projection emits scale and shift together, hence twice `dim`
        self.to_cond = nn.Linear(dim_cond, dim * 2)

    def forward(self, x, cond):
        """Return `x * gamma + beta`, with gamma / beta projected from `cond`.

        NOTE(review): assumes the last dimension of `x` broadcasts against
        `dim` - confirm at call sites.
        """
        modulation = self.to_cond(cond)

        # split the projection in half along the feature (last) dimension
        scale, shift = modulation.chunk(2, dim = -1)

        scaled = x * scale
        return scaled + shift

class SoundStream(nn.Module):
    def __init__(
        self,
@@ -487,6 +496,8 @@ class SoundStream(nn.Module):

        self.encoder_attn = LocalTransformer(**attn_kwargs) if use_local_attn else None

        self.encoder_film = FiLM(codebook_dim, dim_cond = 2)

        self.num_quantizers = rq_num_quantizers

        self.codebook_dim = codebook_dim
@@ -504,6 +515,8 @@ class SoundStream(nn.Module):
            quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        )

        self.decoder_film = FiLM(codebook_dim, dim_cond = 2)

        self.decoder_attn = LocalTransformer(**attn_kwargs) if use_local_attn else None

        decoder_blocks = []
@@ -570,6 +583,10 @@ class SoundStream(nn.Module):

        self.register_buffer('zero', torch.tensor([0.]), persistent = False)

    @property
    def device(self):
        """Device of the first parameter yielded by `self.parameters()`."""
        # only the first parameter is inspected
        # NOTE(review): assumes every parameter lives on the same device - confirm
        first_param = next(self.parameters())
        return first_param.device

    @property
    def configs(self):
        """Unpickle and return the object stored in `self._configs`."""
        # `_configs` is assigned outside this view
        # NOTE(review): presumably the pickled constructor arguments - confirm
        serialized = self._configs
        return pickle.loads(serialized)
@@ -641,17 +658,33 @@ class SoundStream(nn.Module):
            *self.encoder.parameters(),
            *self.decoder.parameters(),
            *(self.encoder_attn.parameters() if exists(self.encoder_attn) else []),
            *(self.decoder_attn.parameters() if exists(self.decoder_attn) else [])
            *(self.decoder_attn.parameters() if exists(self.decoder_attn) else []),
            *self.encoder_film.parameters(),
            *self.decoder_film.parameters()
        ]

    @property
    def seq_len_multiple_of(self):
        """Product of all entries of `self.strides`."""
        def multiply(running_product, stride):
            return running_product * stride

        # no initial value is supplied, so the first stride seeds the fold
        return functools.reduce(multiply, self.strides)

    def process_input(self, x, input_sample_hz = None):
        """Prepare an input tensor and return it with its packed-shape info.

        Steps, in order:
          1. pack `x` with the pattern '* n', keeping the returned shape info
          2. if `input_sample_hz` is given, resample from that rate to
             `self.target_sample_hz`
          3. pass the result through `curtail_to_multiple` with
             `self.seq_len_multiple_of`
          4. if the result is 2-dimensional, rearrange 'b n -> b 1 n'

        Returns:
            tuple of (processed tensor, packed shapes from step 1)

        NOTE(review): `pack`, `resample` and `curtail_to_multiple` are defined
        outside this view; the descriptions above follow their names and
        arguments only - confirm against their definitions.
        """
        audio, ps = pack([x], '* n')

        should_resample = exists(input_sample_hz)

        if should_resample:
            audio = resample(audio, input_sample_hz, self.target_sample_hz)

        multiple = self.seq_len_multiple_of
        audio = curtail_to_multiple(audio, multiple)

        # add a singleton middle axis when there are only two dimensions
        is_two_dimensional = audio.ndim == 2

        if is_two_dimensional:
            audio = rearrange(audio, 'b n -> b 1 n')

        return audio, ps

    def forward(
        self,
        x,
        target = None,
        is_denoising = None, # if you want to learn film conditioners that teach the soundstream to denoise - target would need to be passed in above
        return_encoded = False,
        return_discr_loss = False,
        return_discr_losses_separately = False,
@@ -660,15 +693,12 @@ class SoundStream(nn.Module):
        input_sample_hz = None,
        apply_grad_penalty = False
    ):
        x, ps = pack([x], '* n')
        assert not (exists(is_denoising) and not exists(target))

        if exists(input_sample_hz):
            x = resample(x, input_sample_hz, self.target_sample_hz)
        x, ps = self.process_input(x, input_sample_hz = input_sample_hz)

        x = curtail_to_multiple(x, self.seq_len_multiple_of)

        if x.ndim == 2:
            x = rearrange(x, 'b n -> b 1 n')
        if exists(target):
            target, _ = self.process_input(target, input_sample_hz = input_sample_hz)

        orig_x = x.clone()

@@ -679,11 +709,18 @@ class SoundStream(nn.Module):
        if exists(self.encoder_attn):
            x = self.encoder_attn(x)

        if exists(is_denoising):
            denoise_input = torch.tensor([is_denoising, not is_denoising], dtype = x.dtype, device = self.device) # [1, 0] for denoise, [0, 1] for not denoising
            x = self.encoder_film(x, denoise_input)

        x, indices, commit_loss = self.rq(x)

        if return_encoded:
            return x, indices, commit_loss

        if exists(is_denoising):
            x = self.decoder_film(x, denoise_input)

        if exists(self.decoder_attn):
            x = self.decoder_attn(x)

+1 −1
Original line number Diff line number Diff line
__version__ = '0.28.2'
__version__ = '0.29.0'