Commit fac3152e authored by Phil Wang's avatar Phil Wang
Browse files

complete semantic transformer, as it is a normal transformer

parent 0ec7667b
Loading
Loading
Loading
Loading
+5 −1
Original line number Diff line number Diff line
@@ -46,7 +46,11 @@ loss.backward()

## Todo

- [ ] allow for cross attention based conditioning vs prefix based
- [ ] complete full training code for soundstream, taking care of discriminator training
- [ ] complete CoarseTransformer
- [ ] complete sampling code for both Coarse and Fine Transformers, which will be tricky
- [ ] accommodate variable lengthed audio, bring in eos token
- [ ] full transformer training code for all three transformers

## Citations

+28 −2
Original line number Diff line number Diff line
@@ -549,13 +549,39 @@ class SemanticTransformer(nn.Module):
        **kwargs
    ):
        super().__init__()
        self.start_token = nn.Parameter(torch.randn(dim))

        self.semantic_embedding = nn.Embedding(num_semantic_tokens, dim)

        self.transformer = Transformer(dim = dim, **kwargs)
        self.to_logits = nn.Linear(dim, num_semantic_tokens)

    def forward(
        self,
        semantic_token_ids
        ids,
        return_loss = False
    ):
        raise NotImplemented
        if return_loss:
            labels, ids = ids.clone(), ids[:, :-1]

        tokens = self.semantic_embedding(ids)

        start_tokens = repeat(self.start_token, 'd -> b 1 d', b = ids.shape[0])

        tokens = torch.cat((start_tokens, tokens), dim = 1)

        tokens = self.transformer(tokens)
        logits = self.to_logits(tokens)

        if not return_loss:
            return logits

        loss = F.cross_entropy(
            rearrange(logits, 'b n c -> b c n'),
            labels
        )

        return loss

class CoarseTransformer(nn.Module):
    def __init__(