Commit 0ec7667b authored by Phil Wang
Browse files

handle projection of fine and coarse logits correctly in the final transformer in the hierarchy

parent a9efd2d9
Loading
Loading
Loading
Loading
+37 −4
Original line number Diff line number Diff line
@@ -16,6 +16,9 @@ from vector_quantize_pytorch import ResidualVQ
def exists(val):
    """Return True for any value other than None (falsy values such as 0 or '' still count)."""
    is_missing = val is None
    return not is_missing

def round_down_nearest_multiple(val, mult):
    """Return the largest multiple of `mult` that does not exceed `val` (floor division semantics)."""
    # number of whole groups of size `mult` that fit into `val`
    whole_groups = val // mult
    return whole_groups * mult

# gan losses

def hinge_discr_loss(fake, real):
@@ -593,8 +596,11 @@ class FineTransformer(nn.Module):

        self.transformer = Transformer(dim = dim, **kwargs)

        self.coarse_logits = nn.Linear(dim, codebook_size)
        self.fine_logits = nn.Linear(dim, codebook_size)
        self.num_coarse_quantizers = num_coarse_quantizers
        self.num_fine_quantizers = num_fine_quantizers

        self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size, dim))
        self.fine_logit_weights = nn.Parameter(torch.randn(num_fine_quantizers, codebook_size, dim))

    def forward(
        self,
@@ -616,8 +622,35 @@ class FineTransformer(nn.Module):

        pred_coarse_tokens, pred_fine_tokens = tokens[:, :n], tokens[:, n:]

        coarse_logits = self.coarse_logits(pred_coarse_tokens)
        fine_logits = self.fine_logits(pred_fine_tokens)
        # get coarse logits

        pred_coarse_tokens = rearrange(pred_coarse_tokens, 'b (n q) d -> b n q d', q = self.num_coarse_quantizers)

        coarse_logits = einsum('q c d, b n q d -> b n q c', self.coarse_logit_weights, pred_coarse_tokens)

        coarse_logits = rearrange(coarse_logits, 'b n q c -> b (n q) c')

        # get fine logits

        n = pred_fine_tokens.shape[1]
        nq = round_down_nearest_multiple(n, self.num_fine_quantizers)

        pred_fine_tokens_groupable, pred_fine_tokens_remainder = pred_fine_tokens[:, :nq], pred_fine_tokens[:, nq:]

        pred_fine_tokens_groupable = rearrange(pred_fine_tokens_groupable, 'b (n q) d -> b n q d', q = self.num_fine_quantizers)

        fine_logits_groupable = einsum('q c d, b n q d -> b n q c', self.fine_logit_weights, pred_fine_tokens_groupable)

        fine_logits_groupable = rearrange(fine_logits_groupable, 'b n q c -> b (n q) c')

        remainder_num_quantizers = pred_fine_tokens_remainder.shape[1]

        if remainder_num_quantizers > 0:
            # the remainder is a trailing partial group (fewer positions than
            # num_fine_quantizers). it was sliced straight off pred_fine_tokens,
            # so it is still 3-dimensional (b, q, d) with no separate group axis -
            # the 4-d 'b n q d' pattern used for the groupable part would raise here.
            # position i of the remainder is projected with the i-th quantizer's weights
            fine_logits_remainder = einsum('q c d, b q d -> b q c', self.fine_logit_weights[:remainder_num_quantizers], pred_fine_tokens_remainder)

            # both pieces are (b, positions, c), so they join along the position axis
            fine_logits = torch.cat((fine_logits_groupable, fine_logits_remainder), dim = 1)
        else:
            fine_logits = fine_logits_groupable

        return coarse_logits, fine_logits

+1 −1
Original line number Diff line number Diff line
@@ -18,7 +18,7 @@ setup(
    'audio generation'
  ],
  install_requires=[
    'einops>=0.4',
    'einops>=0.5',
    'ema-pytorch',
    'torch>=1.6',
    'vector-quantize-pytorch>=0.10.5'