Commit 9c4c56d3 authored by Phil Wang's avatar Phil Wang
Browse files

make it work end to end

parent a17b189a
Loading
Loading
Loading
Loading
+24 −1
Original line number Diff line number Diff line
@@ -155,7 +155,29 @@ loss = train_wrapper(
loss.backward()
```

- [ ] show how to generate from prompt tensor or file
All together now

```python

audiolm = AudioLM(
    wav2vec = wav2vec,
    soundstream = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = transformer
)

generated_wav = audiolm(batch_size = 1)

# or with priming

generated_wav_with_prime = audiolm(prime_wave = torch.randn(1, 320 * 8))

# or with text condition, if given

generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the distant echos of bells'])

```

## Appreciation

@@ -192,6 +214,7 @@ loss.backward()
- [ ] abstract out conditioning + classifier free guidance into external module or potentially a package
- [ ] add option to use flash attention
- [ ] simplify training even more within AudioLM class
- [ ] cli tool, something like `audiolm generate <wav.file | text>` and save generated wav file to local directory

## Citations

+206 −41
Original line number Diff line number Diff line
import math
from functools import partial

from typing import Optional, Union
from typing import Optional, Union, List
from typeguard import typechecked

import torch
@@ -21,6 +21,8 @@ from torchaudio.functional import resample

from audiolm_pytorch.soundstream import SoundStream

from tqdm import tqdm

# helper functions

def exists(val):
@@ -356,7 +358,7 @@ class SemanticTransformer(nn.Module):
        self,
        *,
        max_length,
        text = None,
        text: Optional[List[str]] = None,
        text_embeds = None,
        prime_wave = None,
        prime_ids = None,
@@ -396,17 +398,17 @@ class SemanticTransformer(nn.Module):

        batch = ids.shape[0]
        start_length = ids.shape[-1]
        output = ids.clone()
        sample_semantic_ids = ids.clone()

        batch_range = rearrange(torch.arange(batch, device = device), 'b -> b 1')
        last_logit_indices = (ids != self.pad_id).sum(dim = -1).long()

        # sample from transformer

        for ind in range(start_length, max_length):
        for ind in tqdm(range(start_length, max_length), desc = 'generating semantic'):

            logits = self.forward_with_cond_scale(
                ids = output,
                ids = sample_semantic_ids,
                text_embeds = text_embeds,
                **kwargs
            )
@@ -419,18 +421,18 @@ class SemanticTransformer(nn.Module):
            sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

            sampled = rearrange(sampled, 'b -> b 1')
            output = torch.cat((output, sampled), dim = -1)
            sample_semantic_ids = torch.cat((sample_semantic_ids, sampled), dim = -1)

            if all_rows_have_eos_id(output, self.eos_id):
            if all_rows_have_eos_id(sample_semantic_ids, self.eos_id):
                break

            last_logit_indices += 1

        output = mask_out_after_eos_id(output, self.pad_id, include_eos = include_eos_in_output)
        sample_semantic_ids = mask_out_after_eos_id(sample_semantic_ids, self.pad_id, include_eos = include_eos_in_output)

        # ensure all sequences have eos

        has_eos_mask = (output == self.eos_id).any(dim = -1)
        has_eos_mask = (sample_semantic_ids == self.eos_id).any(dim = -1)

        if not has_eos_mask.all():
            append_eos_or_pad = torch.where(
@@ -439,9 +441,9 @@ class SemanticTransformer(nn.Module):
                torch.full((batch, 1), self.eos_id, dtype = torch.long, device = device),
            )

            output = torch.cat((output, append_eos_or_pad), dim = -1)
            sample_semantic_ids = torch.cat((sample_semantic_ids, append_eos_or_pad), dim = -1)

        return output
        return sample_semantic_ids

    def forward_with_cond_scale(
        self,
@@ -463,7 +465,7 @@ class SemanticTransformer(nn.Module):
        raw_wave = None,
        ids = None,
        return_loss = False,
        text = None,
        text: Optional[List[str]] = None,
        text_embeds = None,
        cond_drop_prob = None
    ):
@@ -577,13 +579,19 @@ class CoarseTransformer(nn.Module):
        cond_scale = 3,
        **kwargs
    ):
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)
        semantic_logits, coarse_logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1 or not self.has_condition:
            return logits
            return semantic_logits, coarse_logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale
        null_semantic_logits, null_coarse_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)

        scaled_semantic_logits = None
        if exists(null_semantic_logits):
            scaled_semantic_logits = null_semantic_logits + (semantic_logits - null_semantic_logits) * cond_scale

        scaled_coarse_logits = null_coarse_logits + (coarse_logits - null_coarse_logits) * cond_scale
        return scaled_semantic_logits, scaled_coarse_logits

    def forward(
        self,
@@ -591,7 +599,7 @@ class CoarseTransformer(nn.Module):
        semantic_token_ids,
        coarse_token_ids,
        self_attn_mask = None,
        text = None,
        text: Optional[List[str]] = None,
        text_embeds = None,
        cond_drop_prob = None,
        return_only_coarse_logits = False
@@ -719,22 +727,29 @@ class FineTransformer(nn.Module):
        cond_scale = 3,
        **kwargs
    ):
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)
        coarse_logits, fine_logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1 or not self.has_condition:
            return logits
            return coarse_logits, fine_logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale
        null_coarse_logits, null_fine_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)

        scaled_coarse_logits = None
        if exists(null_coarse_logits):
            scaled_coarse_logits =  null_coarse_logits + (coarse_logits - null_coarse_logits) * cond_scale

        scaled_fine_logits =  null_fine_logits + (fine_logits - null_fine_logits) * cond_scale
        return scaled_coarse_logits, scaled_fine_logits

    def forward(
        self,
        coarse_token_ids,
        fine_token_ids,
        text = None,
        text: Optional[List[str]] = None,
        text_embeds = None,
        cond_drop_prob = None,
        self_attn_mask = None
        self_attn_mask = None,
        return_only_fine_logits = False
    ):
        b, device = coarse_token_ids.shape[0], coarse_token_ids.device
        has_text = exists(text) or exists(text_embeds)
@@ -794,6 +809,9 @@ class FineTransformer(nn.Module):

        pred_coarse_tokens = rearrange(pred_coarse_tokens, 'b (n q) d -> b n q d', q = self.num_coarse_quantizers)

        coarse_logits = None

        if not return_only_fine_logits:
            coarse_logits = einsum('q c d, b n q d -> b n q c', self.coarse_logit_weights, pred_coarse_tokens)

            coarse_logits = rearrange(coarse_logits, 'b n q c -> b (n q) c')
@@ -858,12 +876,15 @@ class CoarseTransformerWrapper(nn.Module):
        self,
        *,
        semantic_token_ids,
        text: Optional[List[str]] = None,
        text_embeds = None,
        max_time_steps = 512,
        cond_scale = 3.,
        filter_thres = 0.9,
        temperature = 1.,
        reshape_output = True,
        reconstruct_wave = False
        reconstruct_wave = False,
        **kwargs
    ):
        batch, device = semantic_token_ids.shape[0], self.device

@@ -871,18 +892,31 @@ class CoarseTransformerWrapper(nn.Module):

        coarse_token_ids = torch.empty((batch, 0), device = device, dtype = torch.long)

        # derive text embeddings if needed

        has_text = exists(text) or exists(text_embeds)
        assert not (self.transformer.has_condition ^ has_text)

        if not exists(text_embeds) and exists(text):
            with torch.no_grad():
                text_embeds = self.transformer.embed_text(text, output_device = device)

        # initialize

        init_coarse_time_step = coarse_token_ids.shape[-1]
        output = coarse_token_ids.clone()
        sampled_coarse_token_ids = coarse_token_ids.clone()

        for time_step in range(init_coarse_time_step, max_time_steps):
        for time_step in tqdm(range(init_coarse_time_step, max_time_steps), desc = 'generating coarse'):
            for ind in range(self.num_coarse_quantizers):
                is_last_step = ind == (self.num_coarse_quantizers - 1)

                _, coarse_logits = self.transformer.forward_with_cond_scale(
                    coarse_token_ids = coarse_token_ids,
                    semantic_token_ids = semantic_token_ids,
                    text_embeds = text_embeds,
                    cond_scale = cond_scale,
                    return_only_coarse_logits = True
                    return_only_coarse_logits = True,
                    **kwargs
                )

                last_coarse_logits = coarse_logits[:, -1]
@@ -894,20 +928,20 @@ class CoarseTransformerWrapper(nn.Module):
                sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

                sampled = rearrange(sampled, 'b -> b 1')
                output = torch.cat((output, sampled), dim = -1)
                sampled_coarse_token_ids = torch.cat((sampled_coarse_token_ids, sampled), dim = -1)

        output = mask_out_after_eos_id(output, self.eos_id, include_eos = False)
        sampled_coarse_token_ids = mask_out_after_eos_id(sampled_coarse_token_ids, self.eos_id, include_eos = False)

        if reshape_output or reconstruct_wave:
            output = rearrange(output, 'b (n q) -> b n q', q = self.num_coarse_quantizers)
            sampled_coarse_token_ids = rearrange(sampled_coarse_token_ids, 'b (n q) -> b n q', q = self.num_coarse_quantizers)

        if reconstruct_wave:
            assert exists(self.soundstream)
            wav = self.soundstream.decode_from_codebook_indices(output)
            wav = self.soundstream.decode_from_codebook_indices(sampled_coarse_token_ids)
            wav = rearrange(wav, 'b 1 n -> b n')
            return wav

        return output
        return sampled_coarse_token_ids

    def forward(
        self,
@@ -998,11 +1032,98 @@ class FineTransformerWrapper(nn.Module):
        self.soundstream = soundstream
        self.transformer = transformer

        self.num_fine_quantizers = transformer.num_fine_quantizers
        self.num_coarse_quantizers = transformer.num_coarse_quantizers
        self.eos_id = transformer.eos_id

        assert self.num_coarse_quantizers > 0

        self.pad_id = pad_id

    @property
    def device(self):
        return next(self.parameters()).device

    @eval_decorator
    @torch.no_grad()
    def generate(
        self,
        *,
        coarse_token_ids,
        text: Optional[List[str]] = None,
        text_embeds = None,
        cond_scale = 3.,
        filter_thres = 0.9,
        temperature = 1.,
        reshape_output = True,
        reconstruct_wave = False,
        **kwargs
    ):
        coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)')

        batch, device = coarse_token_ids.shape[0], self.device

        coarse_token_ids = coarse_token_ids.to(device)

        # derive text embeddings if needed

        has_text = exists(text) or exists(text_embeds)
        assert not (self.transformer.has_condition ^ has_text)

        if not exists(text_embeds) and exists(text):
            with torch.no_grad():
                text_embeds = self.transformer.embed_text(text, output_device = device)

        # initialize

        fine_token_ids = torch.empty((batch, 0), device = device, dtype = torch.long)

        init_fine_time_step = fine_token_ids.shape[-1]
        max_time_steps = coarse_token_ids.shape[1] // self.num_coarse_quantizers

        sampled_fine_token_ids = fine_token_ids.clone()

        for time_step in tqdm(range(init_fine_time_step, max_time_steps), desc = 'generating fine'):
            for ind in range(self.num_fine_quantizers):
                is_last_step = ind == (self.num_fine_quantizers - 1)

                _, fine_logits = self.transformer.forward_with_cond_scale(
                    coarse_token_ids = coarse_token_ids,
                    fine_token_ids = fine_token_ids,
                    text_embeds = text_embeds,
                    cond_scale = cond_scale,
                    return_only_fine_logits = True,
                    **kwargs
                )

                last_fine_logits = fine_logits[:, -1]

                if not is_last_step:
                    last_fine_logits[:, -1] = float('-inf') # prevent from eos if not last quantizer step, but move this to masking logic within the transformer at some point, for both training and eval

                filtered_logits = top_k(last_fine_logits, thres = filter_thres)
                sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

                sampled = rearrange(sampled, 'b -> b 1')
                sampled_fine_token_ids = torch.cat((sampled_fine_token_ids, sampled), dim = -1)

        sampled_fine_token_ids = mask_out_after_eos_id(sampled_fine_token_ids, self.eos_id, include_eos = False)

        if reshape_output or reconstruct_wave:
            sampled_fine_token_ids = rearrange(sampled_fine_token_ids, 'b (n q) -> b n q', q = self.num_fine_quantizers)

        if reconstruct_wave:
            assert exists(self.soundstream)

            coarse_token_ids = rearrange(coarse_token_ids, 'b (n q) -> b n q', q = self.num_coarse_quantizers)
            coarse_and_fine_ids = torch.cat((coarse_token_ids, sampled_fine_token_ids), dim = -1)

            wav = self.soundstream.decode_from_codebook_indices(coarse_and_fine_ids)
            wav = rearrange(wav, 'b 1 n -> b n')
            return wav

        return sampled_fine_token_ids

    def forward(
        self,
        *,
@@ -1072,16 +1193,60 @@ class AudioLM(nn.Module):
    def __init__(
        self,
        *,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]], 
        soundstream: SoundStream,
        semantic_transformer: SemanticTransformer,
        coarse_transformer: CoarseTransformer,
        fine_transformer: FineTransformer,
        fine_transformer: FineTransformer
    ):
        super().__init__()
        self.soundstream = soundstream
        self.semantic = semantic_transformer
        self.coarse = coarse_transformer
        self.fine = fine_transformer

    def forward(self, x):
        raise NotImplemented
        self.coarse = CoarseTransformerWrapper(
            wav2vec = wav2vec,
            soundstream = soundstream,
            transformer = coarse_transformer,
            unique_consecutive = semantic_transformer.unique_consecutive
        )

        self.fine = FineTransformerWrapper(
            soundstream = soundstream,
            transformer = fine_transformer
        )

    @property
    def device(self):
        return next(self.parameters()).device

    @eval_decorator
    @torch.no_grad()
    def forward(
        self,
        *,
        batch_size = 1,
        text: Optional[List[str]] = None,
        prime_wave = None,
        max_length = 2048
    ):
        if exists(prime_wave):
            prime_wave = prime_wave.to(self.device)

        semantic_token_ids = self.semantic.generate(
            text = text,
            batch_size = batch_size,
            prime_wave = prime_wave,
            max_length = max_length
        )

        coarse_token_ids = self.coarse.generate(
            text = text,
            semantic_token_ids = semantic_token_ids
        )

        generated_wave = self.fine.generate(
            text = text,
            coarse_token_ids = coarse_token_ids,
            reconstruct_wave = True
        )

        return generated_wave
+2 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.41',
  version = '0.0.42',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',
@@ -26,6 +26,7 @@ setup(
    'torch>=1.6',
    'torchaudio',
    'transformers',
    'tqdm',
    'typeguard',
    'vector-quantize-pytorch>=0.10.11'
  ],