Loading audiolm_pytorch/audiolm_pytorch.py +33 −14 Original line number Diff line number Diff line import math from functools import partial from functools import partial, wraps from beartype.typing import Optional, Union, List from beartype import beartype Loading Loading @@ -32,6 +32,14 @@ def exists(val): def default(val, d): return val if exists(val) else d def maybe(fn): @wraps(fn) def inner(x, *args, **kwargs): if not exists(x): return x return fn(x, *args, **kwargs) return inner def ceil_div(numer, denom): return (numer + denom - 1) // denom Loading Loading @@ -551,6 +559,7 @@ class CoarseTransformer(nn.Module): cond_as_self_attn_prefix = False, cond_drop_prob = 0.5, grad_shrink_alpha = 0.1, project_semantic_logits = True, **kwargs ): super().__init__() Loading Loading @@ -588,7 +597,7 @@ class CoarseTransformer(nn.Module): self.codebook_size = codebook_size self.num_coarse_quantizers = num_coarse_quantizers self.to_semantic_logits = nn.Linear(dim, num_semantic_tokens + 1) self.to_semantic_logits = nn.Linear(dim, num_semantic_tokens + 1) if project_semantic_logits else None self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size_with_eos, dim)) @property Loading Loading @@ -675,7 +684,7 @@ class CoarseTransformer(nn.Module): # semantic logits semantic_logits = self.to_semantic_logits(pred_semantic_tokens) if not return_only_coarse_logits else None semantic_logits = self.to_semantic_logits(pred_semantic_tokens) if not return_only_coarse_logits and exists(self.to_semantic_logits) else None # get coarse logits Loading Loading @@ -718,6 +727,7 @@ class FineTransformer(nn.Module): cond_as_self_attn_prefix = False, cond_drop_prob = 0.5, grad_shrink_alpha = 0.1, project_coarse_logits = True, **kwargs ): super().__init__() Loading Loading @@ -756,7 +766,7 @@ class FineTransformer(nn.Module): self.num_coarse_quantizers = num_coarse_quantizers self.num_fine_quantizers = num_fine_quantizers self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size_with_eos, dim)) self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size_with_eos, dim)) if project_coarse_logits else None self.fine_logit_weights = nn.Parameter(torch.randn(num_fine_quantizers, codebook_size_with_eos, dim)) @property Loading Loading @@ -856,7 +866,7 @@ class FineTransformer(nn.Module): coarse_logits = None if not return_only_fine_logits: if not return_only_fine_logits and exists(self.coarse_logit_weights): coarse_logits = einsum('q c d, b n q d -> b n q c', self.coarse_logit_weights, pred_coarse_tokens) coarse_logits = rearrange(coarse_logits, 'b n q c -> b (n q) c') Loading Loading @@ -1218,7 +1228,7 @@ class CoarseTransformerWrapper(nn.Module): if not return_loss: return semantic_logits, coarse_logits coarse_logits, semantic_logits = map(lambda t: rearrange(t, 'b n c -> b c n'), (coarse_logits, semantic_logits)) coarse_logits, semantic_logits = map(lambda t: maybe(rearrange)(t, 'b n c -> b c n'), (coarse_logits, semantic_logits)) if self.unique_consecutive: num_coarse_logits, num_semantic_logits = coarse_labels.numel(), (semantic_labels != self.pad_id).sum() Loading @@ -1226,7 +1236,7 @@ class CoarseTransformerWrapper(nn.Module): num_coarse_logits, num_semantic_logits = coarse_logits.shape[-1], semantic_logits.shape[-1] semantic_loss = 0. if self.semantic_cross_entropy_loss_weight > 0: if self.semantic_cross_entropy_loss_weight > 0 and exists(semantic_logits): semantic_loss = F.cross_entropy( semantic_logits, semantic_labels, Loading Loading @@ -1424,12 +1434,16 @@ class FineTransformerWrapper(nn.Module): if not return_loss: return coarse_logits, fine_logits coarse_logits, fine_logits = map(lambda t: rearrange(t, 'b n c -> b c n'), (coarse_logits, fine_logits)) coarse_logits, fine_logits = map(lambda t: maybe(rearrange)(t, 'b n c -> b c n'), (coarse_logits, fine_logits)) num_coarse_logits, num_fine_logits = coarse_logits.shape[-1], fine_logits.shape[-1] num_fine_logits = fine_logits.shape[-1] num_coarse_logits = 0 coarse_loss = 0. if self.coarse_cross_entropy_loss_weight > 0: if self.coarse_cross_entropy_loss_weight > 0 and exists(coarse_logits): num_coarse_logits = coarse_logits.shape[-1] coarse_loss = F.cross_entropy( coarse_logits, coarse_labels Loading Loading @@ -1499,25 +1513,30 @@ class AudioLM(nn.Module): *, batch_size = 1, text: Optional[List[str]] = None, text_embeds: Optional[torch.Tensor] = None, prime_wave = None, max_length = 2048, return_coarse_generated_wave = False, mask_out_generated_fine_tokens = False ): assert not (self.needs_text and not exists(text)), 'text needs to be passed in if one of the transformer requires conditioning' assert not (self.needs_text and (not exists(text) and not exists(text_embeds))), 'text needs to be passed in if one of the transformer requires conditioning' if self.needs_text: if exists(text): text_embeds = self.semantic.embed_text(texts) if exists(prime_wave): prime_wave = prime_wave.to(self.device) semantic_token_ids = self.semantic.generate( text = text if self.semantic_has_condition else None, text_embeds = text_embeds if self.semantic_has_condition else None, batch_size = batch_size, prime_wave = prime_wave, max_length = max_length ) coarse_token_ids_or_recon_wave = self.coarse.generate( text = text if self.coarse_has_condition else None, text_embeds = text_embeds if self.coarse_has_condition else None, semantic_token_ids = semantic_token_ids, reconstruct_wave = return_coarse_generated_wave ) Loading @@ -1526,7 +1545,7 @@ class AudioLM(nn.Module): return coarse_token_ids_or_recon_wave generated_wave = self.fine.generate( text = text if self.fine_has_condition else None, text_embeds = text_embeds if self.fine_has_condition else None, coarse_token_ids = coarse_token_ids_or_recon_wave, reconstruct_wave = True, mask_out_generated_fine_tokens = mask_out_generated_fine_tokens Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.7.9', version = '0.8.1', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/audiolm_pytorch.py +33 −14 Original line number Diff line number Diff line import math from functools import partial from functools import partial, wraps from beartype.typing import Optional, Union, List from beartype import beartype Loading Loading @@ -32,6 +32,14 @@ def exists(val): def default(val, d): return val if exists(val) else d def maybe(fn): @wraps(fn) def inner(x, *args, **kwargs): if not exists(x): return x return fn(x, *args, **kwargs) return inner def ceil_div(numer, denom): return (numer + denom - 1) // denom Loading Loading @@ -551,6 +559,7 @@ class CoarseTransformer(nn.Module): cond_as_self_attn_prefix = False, cond_drop_prob = 0.5, grad_shrink_alpha = 0.1, project_semantic_logits = True, **kwargs ): super().__init__() Loading Loading @@ -588,7 +597,7 @@ class CoarseTransformer(nn.Module): self.codebook_size = codebook_size self.num_coarse_quantizers = num_coarse_quantizers self.to_semantic_logits = nn.Linear(dim, num_semantic_tokens + 1) self.to_semantic_logits = nn.Linear(dim, num_semantic_tokens + 1) if project_semantic_logits else None self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size_with_eos, dim)) @property Loading Loading @@ -675,7 +684,7 @@ class CoarseTransformer(nn.Module): # semantic logits semantic_logits = self.to_semantic_logits(pred_semantic_tokens) if not return_only_coarse_logits else None semantic_logits = self.to_semantic_logits(pred_semantic_tokens) if not return_only_coarse_logits and exists(self.to_semantic_logits) else None # get coarse logits Loading Loading @@ -718,6 +727,7 @@ class FineTransformer(nn.Module): cond_as_self_attn_prefix = False, cond_drop_prob = 0.5, grad_shrink_alpha = 0.1, project_coarse_logits = True, **kwargs ): super().__init__() Loading Loading @@ -756,7 +766,7 @@ class FineTransformer(nn.Module): self.num_coarse_quantizers = num_coarse_quantizers self.num_fine_quantizers = num_fine_quantizers self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size_with_eos, dim)) self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size_with_eos, dim)) if project_coarse_logits else None self.fine_logit_weights = nn.Parameter(torch.randn(num_fine_quantizers, codebook_size_with_eos, dim)) @property Loading Loading @@ -856,7 +866,7 @@ class FineTransformer(nn.Module): coarse_logits = None if not return_only_fine_logits: if not return_only_fine_logits and exists(self.coarse_logit_weights): coarse_logits = einsum('q c d, b n q d -> b n q c', self.coarse_logit_weights, pred_coarse_tokens) coarse_logits = rearrange(coarse_logits, 'b n q c -> b (n q) c') Loading Loading @@ -1218,7 +1228,7 @@ class CoarseTransformerWrapper(nn.Module): if not return_loss: return semantic_logits, coarse_logits coarse_logits, semantic_logits = map(lambda t: rearrange(t, 'b n c -> b c n'), (coarse_logits, semantic_logits)) coarse_logits, semantic_logits = map(lambda t: maybe(rearrange)(t, 'b n c -> b c n'), (coarse_logits, semantic_logits)) if self.unique_consecutive: num_coarse_logits, num_semantic_logits = coarse_labels.numel(), (semantic_labels != self.pad_id).sum() Loading @@ -1226,7 +1236,7 @@ class CoarseTransformerWrapper(nn.Module): num_coarse_logits, num_semantic_logits = coarse_logits.shape[-1], semantic_logits.shape[-1] semantic_loss = 0. if self.semantic_cross_entropy_loss_weight > 0: if self.semantic_cross_entropy_loss_weight > 0 and exists(semantic_logits): semantic_loss = F.cross_entropy( semantic_logits, semantic_labels, Loading Loading @@ -1424,12 +1434,16 @@ class FineTransformerWrapper(nn.Module): if not return_loss: return coarse_logits, fine_logits coarse_logits, fine_logits = map(lambda t: rearrange(t, 'b n c -> b c n'), (coarse_logits, fine_logits)) coarse_logits, fine_logits = map(lambda t: maybe(rearrange)(t, 'b n c -> b c n'), (coarse_logits, fine_logits)) num_coarse_logits, num_fine_logits = coarse_logits.shape[-1], fine_logits.shape[-1] num_fine_logits = fine_logits.shape[-1] num_coarse_logits = 0 coarse_loss = 0. if self.coarse_cross_entropy_loss_weight > 0: if self.coarse_cross_entropy_loss_weight > 0 and exists(coarse_logits): num_coarse_logits = coarse_logits.shape[-1] coarse_loss = F.cross_entropy( coarse_logits, coarse_labels Loading Loading @@ -1499,25 +1513,30 @@ class AudioLM(nn.Module): *, batch_size = 1, text: Optional[List[str]] = None, text_embeds: Optional[torch.Tensor] = None, prime_wave = None, max_length = 2048, return_coarse_generated_wave = False, mask_out_generated_fine_tokens = False ): assert not (self.needs_text and not exists(text)), 'text needs to be passed in if one of the transformer requires conditioning' assert not (self.needs_text and (not exists(text) and not exists(text_embeds))), 'text needs to be passed in if one of the transformer requires conditioning' if self.needs_text: if exists(text): text_embeds = self.semantic.embed_text(texts) if exists(prime_wave): prime_wave = prime_wave.to(self.device) semantic_token_ids = self.semantic.generate( text = text if self.semantic_has_condition else None, text_embeds = text_embeds if self.semantic_has_condition else None, batch_size = batch_size, prime_wave = prime_wave, max_length = max_length ) coarse_token_ids_or_recon_wave = self.coarse.generate( text = text if self.coarse_has_condition else None, text_embeds = text_embeds if self.coarse_has_condition else None, semantic_token_ids = semantic_token_ids, reconstruct_wave = return_coarse_generated_wave ) Loading @@ -1526,7 +1545,7 @@ class AudioLM(nn.Module): return coarse_token_ids_or_recon_wave generated_wave = self.fine.generate( text = text if self.fine_has_condition else None, text_embeds = text_embeds if self.fine_has_condition else None, coarse_token_ids = coarse_token_ids_or_recon_wave, reconstruct_wave = True, mask_out_generated_fine_tokens = mask_out_generated_fine_tokens Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.7.9', version = '0.8.1', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading