Loading README.md +5 −1 Original line number Diff line number Diff line Loading @@ -46,7 +46,11 @@ loss.backward() ## Todo - [ ] allow for cross attention based conditioning vs prefix based - [ ] complete full training code for soundstream, taking care of discriminator training - [ ] complete CoarseTransformer - [ ] complete sampling code for both Coarse and Fine Transformers, which will be tricky - [ ] accommodate variable lengthed audio, bring in eos token - [ ] full transformer training code for all three transformers ## Citations Loading audiolm_pytorch/audiolm_pytorch.py +28 −2 Original line number Diff line number Diff line Loading @@ -549,13 +549,39 @@ class SemanticTransformer(nn.Module): **kwargs ): super().__init__() self.start_token = nn.Parameter(torch.randn(dim)) self.semantic_embedding = nn.Embedding(num_semantic_tokens, dim) self.transformer = Transformer(dim = dim, **kwargs) self.to_logits = nn.Linear(dim, num_semantic_tokens) def forward( self, semantic_token_ids ids, return_loss = False ): raise NotImplemented if return_loss: labels, ids = ids.clone(), ids[:, :-1] tokens = self.semantic_embedding(ids) start_tokens = repeat(self.start_token, 'd -> b 1 d', b = ids.shape[0]) tokens = torch.cat((start_tokens, tokens), dim = 1) tokens = self.transformer(tokens) logits = self.to_logits(tokens) if not return_loss: return logits loss = F.cross_entropy( rearrange(logits, 'b n c -> b c n'), labels ) return loss class CoarseTransformer(nn.Module): def __init__( Loading Loading
README.md +5 −1 Original line number Diff line number Diff line Loading @@ -46,7 +46,11 @@ loss.backward() ## Todo - [ ] allow for cross-attention-based conditioning vs prefix-based - [ ] complete full training code for soundstream, taking care of discriminator training - [ ] complete CoarseTransformer - [ ] complete sampling code for both Coarse and Fine Transformers, which will be tricky - [ ] accommodate variable-length audio, bring in eos token - [ ] full transformer training code for all three transformers ## Citations Loading
audiolm_pytorch/audiolm_pytorch.py +28 −2 Original line number Diff line number Diff line Loading @@ -549,13 +549,39 @@ class SemanticTransformer(nn.Module): **kwargs ): super().__init__() self.start_token = nn.Parameter(torch.randn(dim)) self.semantic_embedding = nn.Embedding(num_semantic_tokens, dim) self.transformer = Transformer(dim = dim, **kwargs) self.to_logits = nn.Linear(dim, num_semantic_tokens) def forward( self, semantic_token_ids ids, return_loss = False ): raise NotImplemented if return_loss: labels, ids = ids.clone(), ids[:, :-1] tokens = self.semantic_embedding(ids) start_tokens = repeat(self.start_token, 'd -> b 1 d', b = ids.shape[0]) tokens = torch.cat((start_tokens, tokens), dim = 1) tokens = self.transformer(tokens) logits = self.to_logits(tokens) if not return_loss: return logits loss = F.cross_entropy( rearrange(logits, 'b n c -> b c n'), labels ) return loss class CoarseTransformer(nn.Module): def __init__( Loading