Commit 7e957b90 authored by Phil Wang's avatar Phil Wang
Browse files

add quantize embeddings to facilitate learning and some product management

parent 65495ad5
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -327,6 +327,7 @@ $ accelerate launch train.py
- [ ] cli tool, something like `audiolm generate <wav.file | text>` and save generated wav file to local directory
- [ ] return a list of waves in the case of variable lengthed audio
- [ ] just take care of the edge case in coarse transformer text conditioned training, where the raw wave is resampled at different frequencies. autodetermine how to route based on length
- [ ] allow for specialized relative positional embeddings in fine transformer based on absolute matching positions of quantizers between coarse and fine

## Citations

+17 −0
Original line number Diff line number Diff line
@@ -603,7 +603,9 @@ class CoarseTransformer(nn.Module):

        self.coarse_eos_id = codebook_size
        codebook_size_with_eos = codebook_size + 1

        self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size_with_eos, dim)
        self.coarse_quantize_embedding = nn.Embedding(num_coarse_quantizers, dim)

        text_dim = default(cond_dim, get_encoded_dim(t5_name))
        self.proj_text_embed = nn.Linear(text_dim, dim, bias = False) if text_dim != dim else nn.Identity()
@@ -692,6 +694,10 @@ class CoarseTransformer(nn.Module):
        semantic_tokens = get_embeds(self.semantic_embedding, semantic_token_ids)
        coarse_tokens = self.coarse_embedding(coarse_token_ids)

        coarse_quantize_tokens = repeat(self.coarse_quantize_embedding.weight, 'q d -> (n q) d', n = ceil_div(coarse_token_ids.shape[-1], self.num_coarse_quantizers))
        coarse_quantize_tokens = coarse_quantize_tokens[:coarse_token_ids.shape[-1], ...]
        coarse_tokens = coarse_tokens + coarse_quantize_tokens

        semantic_seq_len = semantic_tokens.shape[1]

        semantic_start_tokens = repeat(self.semantic_start_token, 'd -> b 1 d', b = b)
@@ -776,6 +782,9 @@ class FineTransformer(nn.Module):
        self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size, dim)
        self.fine_embedding = nn.Embedding(num_fine_quantizers * codebook_size, dim)

        self.coarse_quantize_embedding = nn.Embedding(num_coarse_quantizers, dim)
        self.fine_quantize_embedding = nn.Embedding(num_fine_quantizers, dim)

        self.eos_id = codebook_size

        text_dim = default(cond_dim, get_encoded_dim(t5_name))
@@ -870,6 +879,14 @@ class FineTransformer(nn.Module):
        coarse_tokens = self.coarse_embedding(coarse_token_ids)
        fine_tokens = self.fine_embedding(fine_token_ids)

        coarse_quantize_tokens = repeat(self.coarse_quantize_embedding.weight, 'q d -> (n q) d', n = ceil_div(coarse_token_ids.shape[-1], self.num_coarse_quantizers))
        coarse_quantize_tokens = coarse_quantize_tokens[:coarse_token_ids.shape[-1], ...]
        coarse_tokens = coarse_tokens + coarse_quantize_tokens

        fine_quantize_tokens = repeat(self.fine_quantize_embedding.weight, 'q d -> (n q) d', n = ceil_div(fine_token_ids.shape[-1], self.num_fine_quantizers))
        fine_quantize_tokens = fine_quantize_tokens[:fine_token_ids.shape[-1], ...]
        fine_tokens = fine_tokens + fine_quantize_tokens

        coarse_start_tokens = repeat(self.coarse_start_token, 'd -> b 1 d', b = b)
        fine_start_tokens = repeat(self.fine_start_token, 'd -> b 1 d', b = b)

+1 −1
Original line number Diff line number Diff line
__version__ = '0.19.0'
__version__ = '0.19.1'