Commit ddb40ee3 authored by Phil Wang's avatar Phil Wang
Browse files

give semantic transformer a training wrapper too, refactor and do things right

parent 96800b81
Loading
Loading
Loading
Loading
+30 −55
Original line number Diff line number Diff line
@@ -42,7 +42,7 @@ ex. `SemanticTransformer`

```python
import torch
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
@@ -50,30 +50,29 @@ wav2vec = HubertWithKmeans(
)

semantic_transformer = SemanticTransformer(
    wav2vec = wav2vec,
    num_semantic_tokens = 500,
    dim = 1024,
    depth = 6
).cuda()

wave = torch.randn(1, 320 * 512).cuda()

loss = semantic_transformer(
    raw_wave = wave,
    return_loss = True
trainer = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    folder = '/home/phil/dl/data/LibriSpeech',
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1
)

loss.backward()

# after much training above

sample = semantic_transformer.generate(max_length = 128) # (1, < 128) - may terminate early if it detects [eos]
trainer.train()
```

ex. `CoarseTransformer`

```python
import torch
from audiolm_pytorch import HubertWithKmeans, SoundStream, CoarseTransformer, CoarseTransformerWrapper
from audiolm_pytorch import HubertWithKmeans, SoundStream, CoarseTransformer, CoarseTransformerWrapper, CoarseTransformerTrainer

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
@@ -93,44 +92,31 @@ coarse_transformer = CoarseTransformer(
    depth = 6
)

coarse_wrapper = CoarseTransformerWrapper(
    wav2vec = wav2vec,
trainer = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    soundstream = soundstream,
    transformer = coarse_transformer
).cuda()

wave = torch.randn(1, 32 * 320).cuda()

loss = coarse_wrapper(
    raw_wave = wave,
    return_loss = True
    wav2vec = wav2vec,
    folder = '/home/phil/dl/data/LibriSpeech',
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 10000
)

loss.backward()

# after a lot of training

mock_semantic_token_ids = torch.randint(0, wav2vec.codebook_size, (1, 128))

coarse_tokens = coarse_wrapper.generate(
    semantic_token_ids = mock_semantic_token_ids,
    max_time_steps = 512
) # (1, 512, 3) - (batch, time steps, num quantizers)

trainer.train()
```

ex. `FineTransformer`

```python
import torch
from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerWrapper
from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerWrapper, FineTransformerTrainer

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

soundstream.load('/path/to/trained/soundstream.pt')
# soundstream.load('/path/to/trained/soundstream.pt')

transformer = FineTransformer(
    num_coarse_quantizers = 3,
@@ -140,28 +126,16 @@ transformer = FineTransformer(
    depth = 6
)

train_wrapper = FineTransformerWrapper(
trainer = FineTransformerTrainer(
    transformer = transformer,
    soundstream = soundstream,
    transformer = transformer
).cuda()

wave = torch.randn(1, 320 * 512).cuda()

loss = train_wrapper(
    raw_wave = wave,
    return_loss = True
    folder = '/home/phil/dl/data/LibriSpeech',
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 10000
)

loss.backward()

# after a lot of training

mock_coarse_token_ids = torch.randint(0, 1024, (1, 128, 3))

fine_token_ids = train_wrapper.generate(
    coarse_token_ids = mock_coarse_token_ids
) # (1, 128, 5)

trainer.train()
```

All together now
@@ -217,6 +191,7 @@ generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the d
- [x] add efficient gradient penalty for discriminators for soundstream
- [x] wire up sample hz from sound dataset -> transformers, and have proper resampling during training - think about whether to allow the dataset to have sound files of varying sample hz, or to enforce the same sample hz
- [x] full transformer training code for all three transformers
- [x] refactor so semantic transformer has a wrapper that handles unique consecutives as well as wav to hubert or vq-wav2vec

- [ ] figure out how to do the normalization across each dimension mentioned in the paper, but ignore it for v1 of the framework
- [ ] offer option to weight tie coarse, fine, and semantic embeddings across the 3 hierarchical transformers
@@ -227,8 +202,8 @@ generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the d
- [ ] add option to use flash attention
- [ ] simplify training even more within AudioLM class
- [ ] cli tool, something like `audiolm generate <wav.file | text>` and save generated wav file to local directory
- [ ] refactor so semantic transformer has a wrapper that handles unique consecutives as well as wav to hubert or vq-wav2vec
- [ ] validation function within audiolm that ensures all the pieces are compatible
- [ ] meditate on eos and refactor the entire mess so input never has eos, but eos manually added to labels for prompt sequence

## Citations

+1 −1
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@ from audiolm_pytorch.audiolm_pytorch import AudioLM
from audiolm_pytorch.soundstream import SoundStream

from audiolm_pytorch.audiolm_pytorch import SemanticTransformer, CoarseTransformer, FineTransformer
from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransformerWrapper
from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransformerWrapper, SemanticTransformerWrapper

from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
from audiolm_pytorch.hubert_kmeans import HubertWithKmeans
+173 −142
Original line number Diff line number Diff line
@@ -315,145 +315,34 @@ class SemanticTransformer(nn.Module):
        self,
        *,
        dim,
        num_semantic_tokens = None,
        num_semantic_tokens,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_drop_prob = 0.5,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        unique_consecutive = True,
        grad_shrink_alpha = 0.1,
        pad_id = -1,
        **kwargs
    ):
        super().__init__()
        assert exists(wav2vec) or exists(num_semantic_tokens)

        if exists(wav2vec):
            num_semantic_tokens = default(num_semantic_tokens, wav2vec.codebook_size)
            assert num_semantic_tokens == wav2vec.codebook_size
        self.num_semantic_tokens = num_semantic_tokens

        self.has_condition = has_condition
        self.embed_text = partial(t5_encode_text, name = t5_name)
        self.cond_drop_prob = cond_drop_prob

        self.unique_consecutive = unique_consecutive

        self.start_token = nn.Parameter(torch.randn(dim))

        self.semantic_embedding = nn.Embedding(num_semantic_tokens + 1, dim)
        self.eos_id = num_semantic_tokens
        self.pad_id = pad_id

        self.wav2vec = wav2vec
        self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs)
        self.to_logits = nn.Linear(dim, num_semantic_tokens + 1)

    def non_wav2vec_parameters(self):
        return (
            set([*self.semantic_embedding.parameters()]) |
            set([self.start_token]) |
            set([*self.transformer.parameters()]) |
            set([*self.to_logits.parameters()])
        )

    @property
    def device(self):
        return next(self.parameters()).device

    @eval_decorator
    @torch.no_grad()
    def generate(
        self,
        *,
        max_length,
        text: Optional[List[str]] = None,
        text_embeds = None,
        prime_wave = None,
        prime_ids = None,
        batch_size = 1,
        cond_scale = 3,
        filter_thres = 0.9,
        temperature = 1.,
        include_eos_in_output = True,  # if doing hierarchical sampling, eos must be kept for an easy time
        **kwargs
    ):
        device = self.device

        # derive wav2vec ids from the input wave

        if exists(prime_wave):
            assert not exists(prime_ids)
            assert exists(self.wav2vec)
            ids = self.wav2vec(prime_wave, flatten = False)
        elif exists(prime_ids):
            ids = prime_ids
        else:
            ids = torch.empty((batch_size, 0), dtype = torch.long, device = device)

        if self.unique_consecutive:
            ids = batch_unique_consecutive(ids, pad_value = self.pad_id)

        # derive text embeddings if needed

        has_text = exists(text) or exists(text_embeds)
        assert not (self.has_condition ^ has_text)

        if not exists(text_embeds) and exists(text):
            with torch.no_grad():
                text_embeds = self.embed_text(text, output_device = device)

        # start length and get running id output

        batch = ids.shape[0]
        start_length = ids.shape[-1]
        sample_semantic_ids = ids.clone()

        batch_range = rearrange(torch.arange(batch, device = device), 'b -> b 1')
        last_logit_indices = (ids != self.pad_id).sum(dim = -1).long()

        # sample from transformer

        for ind in tqdm(range(start_length, max_length), desc = 'generating semantic'):

            logits = self.forward_with_cond_scale(
                ids = sample_semantic_ids,
                text_embeds = text_embeds,
                unique_consecutive = False,
                **kwargs
            )

            last_logits = logits[batch_range, last_logit_indices]

            last_logits = rearrange(last_logits, 'b 1 c -> b c')

            filtered_logits = top_k(last_logits, thres = filter_thres)
            sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

            sampled = rearrange(sampled, 'b -> b 1')
            sample_semantic_ids = torch.cat((sample_semantic_ids, sampled), dim = -1)

            if all_rows_have_eos_id(sample_semantic_ids, self.eos_id):
                break

            last_logit_indices += 1

        sample_semantic_ids = mask_out_after_eos_id(sample_semantic_ids, self.pad_id, include_eos = include_eos_in_output)

        # ensure all sequences have eos

        has_eos_mask = (sample_semantic_ids == self.eos_id).any(dim = -1)

        if not has_eos_mask.all():
            append_eos_or_pad = torch.where(
                has_eos_mask,
                torch.full((batch, 1), self.pad_id, dtype = torch.long, device = device),
                torch.full((batch, 1), self.eos_id, dtype = torch.long, device = device),
            )

            sample_semantic_ids = torch.cat((sample_semantic_ids, append_eos_or_pad), dim = -1)

        return sample_semantic_ids

    def forward_with_cond_scale(
        self,
        *args,
@@ -471,7 +360,6 @@ class SemanticTransformer(nn.Module):
    def forward(
        self,
        *,
        raw_wave = None,
        ids = None,
        return_loss = False,
        text: Optional[List[str]] = None,
@@ -480,22 +368,9 @@ class SemanticTransformer(nn.Module):
        unique_consecutive = None
    ):
        device = self.device
        unique_consecutive = default(unique_consecutive, self.unique_consecutive)

        assert exists(raw_wave) ^ exists(ids)

        if not exists(ids):
            assert exists(self.wav2vec)
            ids = self.wav2vec(raw_wave, flatten = False)

        b = ids.shape[0]

        if self.training:
            ids = append_eos_id(ids, self.eos_id)

        if unique_consecutive:
            ids = batch_unique_consecutive(ids, pad_value = self.pad_id)

        has_text = exists(text) or exists(text_embeds)
        assert not (self.has_condition ^ has_text)

@@ -521,18 +396,7 @@ class SemanticTransformer(nn.Module):
        tokens = torch.cat((start_tokens, tokens), dim = 1)

        tokens = self.transformer(tokens, context = text_embeds, context_mask = text_mask)
        logits = self.to_logits(tokens)

        if not return_loss:
            return logits

        loss = F.cross_entropy(
            rearrange(logits, 'b n c -> b c n'),
            labels,
            ignore_index = self.pad_id
        )

        return loss
        return self.to_logits(tokens)

@typechecked
class CoarseTransformer(nn.Module):
@@ -855,6 +719,167 @@ class FineTransformer(nn.Module):

# training wrappers

class SemanticTransformerWrapper(nn.Module):
    """Training and sampling wrapper around a `SemanticTransformer`.

    The wrapper owns everything that is not the transformer proper:
    converting raw waveforms to semantic token ids through the optional
    `wav2vec` module, appending the eos id while training, optionally passing
    ids through `batch_unique_consecutive`, computing the cross entropy loss,
    and autoregressive sampling.
    """

    def __init__(
        self,
        *,
        transformer: SemanticTransformer,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        pad_id = -1,
        unique_consecutive = True
    ):
        """
        Args:
            transformer: the semantic transformer to train / sample from.
            wav2vec: optional module turning raw waveforms into semantic token
                ids. Only required when `raw_wave` / `prime_wave` is passed.
            pad_id: id used for padding; ignored by the loss.
            unique_consecutive: whether ids are passed through
                `batch_unique_consecutive` before being fed to the transformer.
        """
        super().__init__()
        self.wav2vec = wav2vec
        self.transformer = transformer

        # wav2vec is optional (ids may be supplied directly), so the codebook
        # size can only be validated when it was actually given

        if exists(wav2vec):
            assert wav2vec.codebook_size == transformer.num_semantic_tokens, 'wav2vec codebook size must equal num_semantic_tokens of the semantic transformer'

        self.unique_consecutive = unique_consecutive
        self.pad_id = pad_id
        self.eos_id = transformer.eos_id

    @property
    def device(self):
        """Device of the first parameter registered on this module."""
        return next(self.parameters()).device

    @eval_decorator
    @torch.no_grad()
    def generate(
        self,
        *,
        max_length,
        text: Optional[List[str]] = None,
        text_embeds = None,
        prime_wave = None,
        prime_ids = None,
        batch_size = 1,
        cond_scale = 3,
        filter_thres = 0.9,
        temperature = 1.,
        include_eos_in_output = True,  # if doing hierarchical sampling, eos must be kept for an easy time
        **kwargs
    ):
        """Autoregressively sample semantic token ids.

        Args:
            max_length: upper bound on the sampled sequence length, prompt
                included; sampling stops early once every row contains eos.
            text: optional list of strings to condition on.
            text_embeds: optional precomputed text embeddings.
            prime_wave: optional raw waveform prompt (requires `wav2vec`).
            prime_ids: optional prompt given directly as token ids.
            batch_size: number of rows to sample when no prompt is given.
            cond_scale: forwarded to `transformer.forward_with_cond_scale`.
            filter_thres: threshold passed to `top_k` logit filtering.
            temperature: temperature passed to `gumbel_sample`.
            include_eos_in_output: forwarded to `mask_out_after_eos_id`.
            **kwargs: forwarded to `transformer.forward_with_cond_scale`.

        Returns:
            Tensor of sampled ids, one row per batch element.
        """
        device = self.device

        # derive wav2vec ids from the input wave

        if exists(prime_wave):
            assert not exists(prime_ids)
            assert exists(self.wav2vec)
            ids = self.wav2vec(prime_wave, flatten = False)
        elif exists(prime_ids):
            ids = prime_ids
        else:
            ids = torch.empty((batch_size, 0), dtype = torch.long, device = device)

        if self.unique_consecutive:
            ids = batch_unique_consecutive(ids, pad_value = self.pad_id)

        # derive text embeddings if needed

        has_text = exists(text) or exists(text_embeds)
        assert not (self.transformer.has_condition ^ has_text)

        # the whole method already runs under torch.no_grad via the decorator

        if not exists(text_embeds) and exists(text):
            text_embeds = self.transformer.embed_text(text, output_device = device)

        # start length and get running id output

        batch = ids.shape[0]
        start_length = ids.shape[-1]
        sample_semantic_ids = ids.clone()

        # both index tensors are kept as (batch, 1) so that advanced indexing
        # picks one position per row; a (batch, 1) index paired with a
        # (batch,) index would broadcast to (batch, batch) and break for batch > 1

        batch_range = rearrange(torch.arange(batch, device = device), 'b -> b 1')
        last_logit_indices = (ids != self.pad_id).sum(dim = -1, keepdim = True).long()

        # sample from transformer

        for _ in tqdm(range(start_length, max_length), desc = 'generating semantic'):

            logits = self.transformer.forward_with_cond_scale(
                ids = sample_semantic_ids,
                text_embeds = text_embeds,
                cond_scale = cond_scale,
                **kwargs
            )

            last_logits = logits[batch_range, last_logit_indices]

            last_logits = rearrange(last_logits, 'b 1 c -> b c')

            filtered_logits = top_k(last_logits, thres = filter_thres)
            sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

            sampled = rearrange(sampled, 'b -> b 1')
            sample_semantic_ids = torch.cat((sample_semantic_ids, sampled), dim = -1)

            if all_rows_have_eos_id(sample_semantic_ids, self.eos_id):
                break

            last_logit_indices += 1

        sample_semantic_ids = mask_out_after_eos_id(sample_semantic_ids, self.pad_id, include_eos = include_eos_in_output)

        # ensure all sequences have eos
        # mask is kept as (batch, 1) so torch.where selects per row instead of
        # broadcasting a (batch,) mask across a (batch, 1) column

        has_eos_mask = (sample_semantic_ids == self.eos_id).any(dim = -1, keepdim = True)

        if not has_eos_mask.all():
            append_eos_or_pad = torch.where(
                has_eos_mask,
                torch.full((batch, 1), self.pad_id, dtype = torch.long, device = device),
                torch.full((batch, 1), self.eos_id, dtype = torch.long, device = device),
            )

            sample_semantic_ids = torch.cat((sample_semantic_ids, append_eos_or_pad), dim = -1)

        return sample_semantic_ids

    def forward(
        self,
        *,
        semantic_token_ids = None,
        raw_wave = None,
        text = None,
        text_embeds = None,
        return_loss = False,
        **kwargs
    ):
        """Run the transformer on semantic ids, optionally returning the loss.

        Args:
            semantic_token_ids: token ids; takes precedence over `raw_wave`.
            raw_wave: raw waveform, converted to ids with `wav2vec` when
                `semantic_token_ids` is not given.
            text: optional text conditioning, forwarded to the transformer.
            text_embeds: optional text embeddings, forwarded to the transformer.
            return_loss: if True, return the cross entropy loss instead of logits.
            **kwargs: forwarded to the transformer.

        Returns:
            Logits from the transformer, or a scalar loss if `return_loss`.
        """
        assert exists(raw_wave) or exists(semantic_token_ids), 'either raw waveform (raw_wave) is given or semantic token ids are given (semantic_token_ids)'

        if not exists(semantic_token_ids):
            assert exists(self.wav2vec), 'VQWav2Vec must be be provided if given raw wave for training'
            semantic_token_ids = self.wav2vec(raw_wave, flatten = False)

        # collapse every non-batch dimension into a single sequence dimension

        semantic_token_ids = rearrange(semantic_token_ids, 'b ... -> b (...)')

        if self.training:
            semantic_token_ids = append_eos_id(semantic_token_ids, self.transformer.eos_id)

        if self.unique_consecutive:
            semantic_token_ids = batch_unique_consecutive(semantic_token_ids, pad_value = self.pad_id)

        # when computing the loss, the last id is dropped from the input while
        # the full sequence is used as labels
        # NOTE(review): presumably the transformer prepends a start token so the
        # logits line up with the labels - confirm against SemanticTransformer.forward

        input_ids = semantic_token_ids
        if return_loss:
            input_ids = semantic_token_ids[:, :-1]

        logits = self.transformer(
            ids = input_ids,
            text = text,
            text_embeds = text_embeds,
            **kwargs
        )

        if not return_loss:
            return logits

        # positions labelled with pad_id do not contribute to the loss

        loss = F.cross_entropy(
            rearrange(logits, 'b n c -> b c n'),
            semantic_token_ids,
            ignore_index = self.pad_id
        )

        return loss

@typechecked
class CoarseTransformerWrapper(nn.Module):
    def __init__(
@@ -1227,16 +1252,22 @@ class AudioLM(nn.Module):
        soundstream: SoundStream,
        semantic_transformer: SemanticTransformer,
        coarse_transformer: CoarseTransformer,
        fine_transformer: FineTransformer
        fine_transformer: FineTransformer,
        unique_consecutive = True
    ):
        super().__init__()
        self.semantic = semantic_transformer

        self.semantic = SemanticTransformerWrapper(
            wav2vec = wav2vec,
            transformer = semantic_transformer,
            unique_consecutive = unique_consecutive
        )

        self.coarse = CoarseTransformerWrapper(
            wav2vec = wav2vec,
            soundstream = soundstream,
            transformer = coarse_transformer,
            unique_consecutive = semantic_transformer.unique_consecutive
            unique_consecutive = unique_consecutive
        )

        self.fine = FineTransformerWrapper(
+17 −13
Original line number Diff line number Diff line
@@ -27,6 +27,7 @@ from audiolm_pytorch.soundstream import SoundStream

from audiolm_pytorch.audiolm_pytorch import (
    SemanticTransformer,
    SemanticTransformerWrapper,
    CoarseTransformer,
    CoarseTransformerWrapper,
    FineTransformer,
@@ -341,6 +342,7 @@ class SoundStreamTrainer(nn.Module):
class SemanticTransformerTrainer(nn.Module):
    def __init__(
        self,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
        transformer: SemanticTransformer,
        *,
        num_train_steps,
@@ -361,8 +363,14 @@ class SemanticTransformerTrainer(nn.Module):
        super().__init__()
        self.accelerator = Accelerator(**accelerate_kwargs)

        self.wav2vec = wav2vec
        self.transformer = transformer

        self.train_wrapper = SemanticTransformerWrapper(
            wav2vec = wav2vec,
            transformer = transformer
        )

        self.register_buffer('steps', torch.Tensor([0]))

        self.num_train_steps = num_train_steps
@@ -371,7 +379,7 @@ class SemanticTransformerTrainer(nn.Module):

        # optimizers

        self.optim = get_optimizer(transformer.non_wav2vec_parameters(), lr = lr, wd = wd)
        self.optim = get_optimizer(transformer.parameters(), lr = lr, wd = wd)

        # max grad norm

@@ -382,8 +390,8 @@ class SemanticTransformerTrainer(nn.Module):
        self.ds = SoundDataset(
            folder,
            max_length = data_max_length,
            target_sample_hz = transformer.wav2vec.target_sample_hz,
            seq_len_multiple_of = transformer.wav2vec.seq_len_multiple_of
            target_sample_hz = wav2vec.target_sample_hz,
            seq_len_multiple_of = wav2vec.seq_len_multiple_of
        )

        # split for validation
@@ -406,12 +414,12 @@ class SemanticTransformerTrainer(nn.Module):
        # prepare with accelerator

        (
            self.transformer,
            self.train_wrapper,
            self.optim,
            self.dl,
            self.valid_dl
        ) = self.accelerator.prepare(
            self.transformer,
            self.train_wrapper,
            self.optim,
            self.dl,
            self.valid_dl
@@ -467,14 +475,14 @@ class SemanticTransformerTrainer(nn.Module):
        for _ in range(self.grad_accum_every):
            wave = next(self.dl_iter).to(device)

            loss = self.transformer(raw_wave = wave, return_loss = True)
            loss = self.train_wrapper(raw_wave = wave, return_loss = True)

            self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.transformer.non_wav2vec_parameters(), self.max_grad_norm)
            self.accelerator.clip_grad_norm_(self.transformer.parameters(), self.max_grad_norm)

        self.optim.step()
        self.optim.zero_grad()
@@ -486,15 +494,11 @@ class SemanticTransformerTrainer(nn.Module):
        # sample results every so often

        if self.is_main and not (steps % self.save_results_every):
            model = self.transformer
            filename = str(steps)

            model.eval()

            wave = next(self.valid_dl_iter).to(device)

            with torch.no_grad():
                valid_loss = model(raw_wave = wave, return_loss = True)
                self.train_wrapper.eval()
                valid_loss = self.train_wrapper(raw_wave = wave, return_loss = True)

            self.print(f'{steps}: valid loss {valid_loss}')