Commit 74c93f5a authored by Phil Wang's avatar Phil Wang
Browse files

take care of masking out fine token ids in the same sequence positions where...

take care of masking out fine token ids at the sequence positions where the coarse token ids are all padding, in preparation for returning a list of variable-length waves
parent 6a77a50c
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -206,6 +206,7 @@ generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the d
- [ ] simplify training even more within AudioLM class
- [ ] cli tool, something like `audiolm generate <wav.file | text>` and save generated wav file to local directory
- [ ] validation function within audiolm that ensures all the pieces are compatible
- [ ] return a list of waves in the case of variable-length audio

## Citations

+34 −23
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ from torch.autograd import grad as torch_grad
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from einops import rearrange, repeat
from einops import rearrange, repeat, reduce

from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
from audiolm_pytorch.hubert_kmeans import HubertWithKmeans
@@ -902,7 +902,6 @@ class CoarseTransformerWrapper(nn.Module):
        cond_scale = 3.,
        filter_thres = 0.9,
        temperature = 1.,
        reshape_output = True,
        reconstruct_wave = False,
        **kwargs
    ):
@@ -951,17 +950,15 @@ class CoarseTransformerWrapper(nn.Module):
                sampled_coarse_token_ids = torch.cat((sampled_coarse_token_ids, sampled), dim = -1)

        sampled_coarse_token_ids = mask_out_after_eos_id(sampled_coarse_token_ids, self.coarse_eos_id, keep_eos = False)

        if reshape_output or reconstruct_wave:
        sampled_coarse_token_ids = rearrange(sampled_coarse_token_ids, 'b (n q) -> b n q', q = self.num_coarse_quantizers)

        if reconstruct_wave:
        if not reconstruct_wave:
            return sampled_coarse_token_ids

        assert exists(self.soundstream)
            wav = self.soundstream.decode_from_codebook_indices(sampled_coarse_token_ids)
            wav = rearrange(wav, 'b 1 n -> b n')
            return wav

        return sampled_coarse_token_ids
        wav = self.soundstream.decode_from_codebook_indices(sampled_coarse_token_ids)
        return rearrange(wav, 'b 1 n -> b n')

    def forward(
        self,
@@ -1087,8 +1084,8 @@ class FineTransformerWrapper(nn.Module):
        cond_scale = 3.,
        filter_thres = 0.9,
        temperature = 1.,
        reshape_output = True,
        reconstruct_wave = False,
        mask_out_generated_fine_tokens = False,
        **kwargs
    ):
        coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)')
@@ -1141,20 +1138,32 @@ class FineTransformerWrapper(nn.Module):

        sampled_fine_token_ids = mask_out_after_eos_id(sampled_fine_token_ids, self.eos_id, keep_eos = False)

        if reshape_output or reconstruct_wave:
        # reshape coarse and fine tokens for quantization dimension

        sampled_fine_token_ids = rearrange(sampled_fine_token_ids, 'b (n q) -> b n q', q = self.num_fine_quantizers)
        coarse_token_ids = rearrange(coarse_token_ids, 'b (n q) -> b n q', q = self.num_coarse_quantizers)

        # whether to mask out fine token positions where the coarse token ids are all padding (variable-length training)

        if mask_out_generated_fine_tokens:
            pos_is_all_padding = (coarse_token_ids == self.pad_id).all(dim = -1, keepdim = True)
            seq_lengths = reduce(~pos_is_all_padding, 'b n 1 -> b', 'sum')

            sampled_fine_token_ids = sampled_fine_token_ids.masked_fill(pos_is_all_padding, self.pad_id)

        # if not reconstructing wave, return just the fine token ids

        if not reconstruct_wave:
            return sampled_fine_token_ids

        # reconstruct the wave using soundstream, first concatenating the coarse and fine token ids along the quantization dimension

        if reconstruct_wave:
        assert exists(self.soundstream)

            coarse_token_ids = rearrange(coarse_token_ids, 'b (n q) -> b n q', q = self.num_coarse_quantizers)
        coarse_and_fine_ids = torch.cat((coarse_token_ids, sampled_fine_token_ids), dim = -1)

        wav = self.soundstream.decode_from_codebook_indices(coarse_and_fine_ids)
            wav = rearrange(wav, 'b 1 n -> b n')
            return wav

        return sampled_fine_token_ids
        return rearrange(wav, 'b 1 n -> b n')

    def forward(
        self,
@@ -1272,7 +1281,8 @@ class AudioLM(nn.Module):
        text: Optional[List[str]] = None,
        prime_wave = None,
        max_length = 2048,
        return_coarse_generated_wave = False
        return_coarse_generated_wave = False,
        mask_out_generated_fine_tokens = False
    ):
        if exists(prime_wave):
            prime_wave = prime_wave.to(self.device)
@@ -1296,7 +1306,8 @@ class AudioLM(nn.Module):
        generated_wave = self.fine.generate(
            text = text,
            coarse_token_ids = coarse_token_ids_or_recon_wave,
            reconstruct_wave = True
            reconstruct_wave = True,
            mask_out_generated_fine_tokens = mask_out_generated_fine_tokens
        )

        return generated_wave
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.59',
  version = '0.0.60',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',