Commit 65d9c407 authored by Phil Wang's avatar Phil Wang
Browse files

resolve issue with eos, maybe

parent d2932c75
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -194,6 +194,7 @@ generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the d
- [x] wire up sample hz from sound dataset -> transformers, and have proper resampling within during training - think about whether to allow for dataset to have sound files of varying or enforce same sample hz
- [x] full transformer training code for all three transformers
- [x] refactor so semantic transformer has a wrapper to that handles unique consecutives as well as wav to hubert or vq-wav2vec
- [x] simply not self attend to eos token on the prompting side (semantic for coarse transformer, coarse for fine transformer)

- [ ] figure out how to do the normalization across each dimension mentioned in the paper, but ignore it for v1 of the framework
- [ ] offer option to weight tie coarse, fine, and semantic embeddings across the 3 hierarchical transformers
@@ -205,7 +206,6 @@ generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the d
- [ ] simplify training even more within AudioLM class
- [ ] cli tool, something like `audiolm generate <wav.file | text>` and save generated wav file to local directory
- [ ] validation function within audiolm that ensures all the pieces are compatible
- [ ] meditate on eos and refactor the entire mess so input never has eos, but eos manually added to labels for prompt sequence

## Citations

+17 −26
Original line number Diff line number Diff line
@@ -74,10 +74,10 @@ def top_k(logits, thres = 0.5):
    probs.scatter_(1, ind, val)
    return probs

def mask_out_after_eos_id(t, eos_id, mask_value = -1, include_eos = True):
def mask_out_after_eos_id(t, eos_id, mask_value = -1, keep_eos = True):
    eos_mask = (t == eos_id).float()

    if include_eos:
    if keep_eos:
        eos_mask = F.pad(eos_mask, (1, -1))

    after_eos_mask = eos_mask.cumsum(dim = -1) > 0
@@ -810,20 +810,7 @@ class SemanticTransformerWrapper(nn.Module):

            last_logit_indices += 1

        sample_semantic_ids = mask_out_after_eos_id(sample_semantic_ids, self.pad_id, include_eos = include_eos_in_output)

        # ensure all sequences have eos

        has_eos_mask = (sample_semantic_ids == self.eos_id).any(dim = -1)

        if not has_eos_mask.all():
            append_eos_or_pad = torch.where(
                has_eos_mask,
                torch.full((batch, 1), self.pad_id, dtype = torch.long, device = device),
                torch.full((batch, 1), self.eos_id, dtype = torch.long, device = device),
            )

            sample_semantic_ids = torch.cat((sample_semantic_ids, append_eos_or_pad), dim = -1)
        sample_semantic_ids = mask_out_after_eos_id(sample_semantic_ids, self.pad_id, keep_eos = False)

        return sample_semantic_ids

@@ -896,7 +883,8 @@ class CoarseTransformerWrapper(nn.Module):
        self.semantic_cross_entropy_loss_weight = semantic_cross_entropy_loss_weight

        self.num_coarse_quantizers = transformer.num_coarse_quantizers
        self.eos_id = transformer.coarse_eos_id
        self.semantic_eos_id = transformer.semantic_eos_id
        self.coarse_eos_id = transformer.coarse_eos_id

    @property
    def device(self):
@@ -962,7 +950,7 @@ class CoarseTransformerWrapper(nn.Module):
                sampled = rearrange(sampled, 'b -> b 1')
                sampled_coarse_token_ids = torch.cat((sampled_coarse_token_ids, sampled), dim = -1)

        sampled_coarse_token_ids = mask_out_after_eos_id(sampled_coarse_token_ids, self.eos_id, include_eos = False)
        sampled_coarse_token_ids = mask_out_after_eos_id(sampled_coarse_token_ids, self.coarse_eos_id, keep_eos = False)

        if reshape_output or reconstruct_wave:
            sampled_coarse_token_ids = rearrange(sampled_coarse_token_ids, 'b (n q) -> b n q', q = self.num_coarse_quantizers)
@@ -1018,12 +1006,13 @@ class CoarseTransformerWrapper(nn.Module):
            semantic_labels, coarse_labels = semantic_token_ids, coarse_token_ids.clone()
            coarse_token_ids = coarse_token_ids[:, :-1]

        self_attn_mask = None
        if self.unique_consecutive:
            self_attn_mask = semantic_token_ids != self.pad_id
        # self attention mask would omit any padding and eos tokens in the semantic prime

        self_attn_mask = (semantic_token_ids != self.pad_id) & (semantic_token_ids != self.semantic_eos_id)
        semantic_token_ids = semantic_token_ids.masked_fill(~self_attn_mask, 0)

        coarse_token_len = coarse_token_ids.shape[-1]
            self_attn_mask = F.pad(self_attn_mask, (1, coarse_token_len + 1), value = True)
        self_attn_mask = F.pad(self_attn_mask, (1, coarse_token_len + 1), value = True) # attend to semantic bos and all coarse tokens

        semantic_logits, coarse_logits = self.transformer(
            semantic_token_ids = semantic_token_ids,
@@ -1150,7 +1139,7 @@ class FineTransformerWrapper(nn.Module):
                sampled = rearrange(sampled, 'b -> b 1')
                sampled_fine_token_ids = torch.cat((sampled_fine_token_ids, sampled), dim = -1)

        sampled_fine_token_ids = mask_out_after_eos_id(sampled_fine_token_ids, self.eos_id, include_eos = False)
        sampled_fine_token_ids = mask_out_after_eos_id(sampled_fine_token_ids, self.eos_id, keep_eos = False)

        if reshape_output or reconstruct_wave:
            sampled_fine_token_ids = rearrange(sampled_fine_token_ids, 'b (n q) -> b n q', q = self.num_fine_quantizers)
@@ -1197,7 +1186,9 @@ class FineTransformerWrapper(nn.Module):
            coarse_labels, fine_labels = coarse_token_ids, fine_token_ids.clone()
            fine_token_ids = fine_token_ids[:, :-1]

        self_attn_mask = coarse_token_ids != self.pad_id
        # do not attend to any of the coarse padding tokens or coarse end token either

        self_attn_mask = (coarse_token_ids != self.pad_id) & (coarse_token_ids != self.eos_id)
        coarse_token_ids = coarse_token_ids.masked_fill(~self_attn_mask, 0)

        fine_token_seq_len = fine_token_ids.shape[-1]
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.56',
  version = '0.0.57',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',