Loading audiolm_pytorch/audiolm_pytorch.py +37 −4 Original line number Diff line number Diff line Loading @@ -16,6 +16,9 @@ from vector_quantize_pytorch import ResidualVQ def exists(val): return val is not None def round_down_nearest_multiple(val, mult): return (val // mult) * mult # gan losses def hinge_discr_loss(fake, real): Loading Loading @@ -593,8 +596,11 @@ class FineTransformer(nn.Module): self.transformer = Transformer(dim = dim, **kwargs) self.coarse_logits = nn.Linear(dim, codebook_size) self.fine_logits = nn.Linear(dim, codebook_size) self.num_coarse_quantizers = num_coarse_quantizers self.num_fine_quantizers = num_fine_quantizers self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size, dim)) self.fine_logit_weights = nn.Parameter(torch.randn(num_fine_quantizers, codebook_size, dim)) def forward( self, Loading @@ -616,8 +622,35 @@ class FineTransformer(nn.Module): pred_coarse_tokens, pred_fine_tokens = tokens[:, :n], tokens[:, n:] coarse_logits = self.coarse_logits(pred_coarse_tokens) fine_logits = self.fine_logits(pred_fine_tokens) # get coarse logits pred_coarse_tokens = rearrange(pred_coarse_tokens, 'b (n q) d -> b n q d', q = self.num_coarse_quantizers) coarse_logits = einsum('q c d, b n q d -> b n q c', self.coarse_logit_weights, pred_coarse_tokens) coarse_logits = rearrange(coarse_logits, 'b n q c -> b (n q) c') # get fine logits n = pred_fine_tokens.shape[1] nq = round_down_nearest_multiple(n, self.num_fine_quantizers) pred_fine_tokens_groupable, pred_fine_tokens_remainder = pred_fine_tokens[:, :nq], pred_fine_tokens[:, nq:] pred_fine_tokens_groupable = rearrange(pred_fine_tokens_groupable, 'b (n q) d -> b n q d', q = self.num_fine_quantizers) fine_logits_groupable = einsum('q c d, b n q d -> b n q c', self.fine_logit_weights, pred_fine_tokens_groupable) fine_logits_groupable = rearrange(fine_logits_groupable, 'b n q c -> b (n q) c') remainder_num_quantizers = pred_fine_tokens_remainder.shape[1] if remainder_num_quantizers > 0: fine_logits_remainder = einsum('q c d, b n q d -> b n q c', self.fine_logit_weights[:remainder_num_quantizers], pred_fine_tokens_remainder) fine_logits = torch.cat((fine_logits_groupable, fine_logits_remainder), dim = 1) else: fine_logits = fine_logits_groupable return coarse_logits, fine_logits Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -18,7 +18,7 @@ setup( 'audio generation' ], install_requires=[ 'einops>=0.4', 'einops>=0.5', 'ema-pytorch', 'torch>=1.6', 'vector-quantize-pytorch>=0.10.5' Loading Loading
audiolm_pytorch/audiolm_pytorch.py +37 −4 Original line number Diff line number Diff line Loading @@ -16,6 +16,9 @@ from vector_quantize_pytorch import ResidualVQ def exists(val): return val is not None def round_down_nearest_multiple(val, mult): return (val // mult) * mult # gan losses def hinge_discr_loss(fake, real): Loading Loading @@ -593,8 +596,11 @@ class FineTransformer(nn.Module): self.transformer = Transformer(dim = dim, **kwargs) self.coarse_logits = nn.Linear(dim, codebook_size) self.fine_logits = nn.Linear(dim, codebook_size) self.num_coarse_quantizers = num_coarse_quantizers self.num_fine_quantizers = num_fine_quantizers self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size, dim)) self.fine_logit_weights = nn.Parameter(torch.randn(num_fine_quantizers, codebook_size, dim)) def forward( self, Loading @@ -616,8 +622,35 @@ class FineTransformer(nn.Module): pred_coarse_tokens, pred_fine_tokens = tokens[:, :n], tokens[:, n:] coarse_logits = self.coarse_logits(pred_coarse_tokens) fine_logits = self.fine_logits(pred_fine_tokens) # get coarse logits pred_coarse_tokens = rearrange(pred_coarse_tokens, 'b (n q) d -> b n q d', q = self.num_coarse_quantizers) coarse_logits = einsum('q c d, b n q d -> b n q c', self.coarse_logit_weights, pred_coarse_tokens) coarse_logits = rearrange(coarse_logits, 'b n q c -> b (n q) c') # get fine logits n = pred_fine_tokens.shape[1] nq = round_down_nearest_multiple(n, self.num_fine_quantizers) pred_fine_tokens_groupable, pred_fine_tokens_remainder = pred_fine_tokens[:, :nq], pred_fine_tokens[:, nq:] pred_fine_tokens_groupable = rearrange(pred_fine_tokens_groupable, 'b (n q) d -> b n q d', q = self.num_fine_quantizers) fine_logits_groupable = einsum('q c d, b n q d -> b n q c', self.fine_logit_weights, pred_fine_tokens_groupable) fine_logits_groupable = rearrange(fine_logits_groupable, 'b n q c -> b (n q) c') remainder_num_quantizers = pred_fine_tokens_remainder.shape[1] if remainder_num_quantizers > 0: fine_logits_remainder = einsum('q c d, b n q d -> b n q c', self.fine_logit_weights[:remainder_num_quantizers], pred_fine_tokens_remainder) fine_logits = torch.cat((fine_logits_groupable, fine_logits_remainder), dim = 1) else: fine_logits = fine_logits_groupable return coarse_logits, fine_logits Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -18,7 +18,7 @@ setup( 'audio generation' ], install_requires=[ 'einops>=0.4', 'einops>=0.5', 'ema-pytorch', 'torch>=1.6', 'vector-quantize-pytorch>=0.10.5' Loading