Loading README.md +10 −0 Original line number Diff line number Diff line Loading @@ -275,3 +275,13 @@ generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the d primaryClass = {cs.CV} } ``` ```bibtex @article{Liu2022FCMFC, title = {FCM: Forgetful Causal Masking Makes Causal Language Models Better Zero-Shot Learners}, author = {Hao Liu and Xinyang Geng and Lisa Lee and Igor Mordatch and Sergey Levine and Sharan Narang and P. Abbeel}, journal = {ArXiv}, year = {2022}, volume = {abs/2210.13432} } ``` audiolm_pytorch/audiolm_pytorch.py +92 −23 Original line number Diff line number Diff line Loading @@ -49,6 +49,17 @@ def eval_decorator(fn): return out return inner # tensor helpers def generate_mask_with_prob(shape, mask_prob, device): seq = shape[-1] rand = torch.randn(shape, device = device) rand[:, 0] = -torch.finfo(rand.dtype).max num_mask = min(int(seq * mask_prob), seq - 1) indices = rand.topk(num_mask, dim = -1).indices mask = ~torch.zeros(shape, device = device).scatter(1, indices, 1.).bool() return mask # attention related utils def grad_shrink(t, alpha = 0.1): Loading Loading @@ -291,8 +302,8 @@ class Transformer(nn.Module): heads, dim_context = None, cross_attend = False, attn_dropout = 0.1, ff_dropout = 0.1, attn_dropout = 0., ff_dropout = 0., grad_shrink_alpha = 0.1, **kwargs ): Loading Loading @@ -346,10 +357,10 @@ class SemanticTransformer(nn.Module): *, dim, depth, heads, num_semantic_tokens, attn_dropout = 0.1, ff_dropout = 0.1, heads = 8, attn_dropout = 0., ff_dropout = 0., t5_name = DEFAULT_T5_NAME, has_condition = False, cond_drop_prob = 0.5, Loading @@ -370,7 +381,8 @@ class SemanticTransformer(nn.Module): self.eos_id = num_semantic_tokens self.pad_id = pad_id self.transformer = Transformer(dim = dim, self.transformer = Transformer( dim = dim, depth = depth, heads = heads, attn_dropout = attn_dropout, Loading @@ -378,7 +390,9 @@ class SemanticTransformer(nn.Module): dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) **kwargs ) self.to_logits = nn.Linear(dim, num_semantic_tokens + 1) @property Loading Loading @@ -406,6 +420,7 @@ class SemanticTransformer(nn.Module): return_loss = False, text: Optional[List[str]] = None, text_embeds = None, self_attn_mask = None, cond_drop_prob = None, unique_consecutive = None ): Loading Loading @@ -437,7 +452,10 @@ class SemanticTransformer(nn.Module): tokens = torch.cat((start_tokens, tokens), dim = 1) tokens = self.transformer(tokens, context = text_embeds, context_mask = text_mask) if exists(self_attn_mask): self_attn_mask = F.pad(self_attn_mask, (1, 0), value = True) tokens = self.transformer(tokens, context = text_embeds, self_attn_mask = self_attn_mask, context_mask = text_mask) return self.to_logits(tokens) @beartype Loading @@ -449,8 +467,10 @@ class CoarseTransformer(nn.Module): num_coarse_quantizers, dim, depth, heads, num_semantic_tokens, heads = 8, attn_dropout = 0., ff_dropout = 0., t5_name = DEFAULT_T5_NAME, has_condition = False, cond_drop_prob = 0.5, Loading @@ -472,7 +492,17 @@ class CoarseTransformer(nn.Module): codebook_size_with_eos = codebook_size + 1 self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size_with_eos, dim) self.transformer = Transformer(dim = dim, depth = depth, heads = heads, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.transformer = Transformer( dim = dim, depth = depth, heads = heads, attn_dropout = attn_dropout, ff_dropout = ff_dropout, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs ) self.codebook_size = codebook_size self.num_coarse_quantizers = num_coarse_quantizers Loading Loading @@ -597,7 +627,9 @@ class FineTransformer(nn.Module): codebook_size, dim, depth, heads, heads = 8, attn_dropout = 0., ff_dropout = 0., t5_name = DEFAULT_T5_NAME, has_condition = False, cond_drop_prob = 0.5, Loading @@ -621,7 +653,17 @@ class FineTransformer(nn.Module): self.eos_id = codebook_size self.transformer = Transformer(dim = dim, depth = depth, heads = heads, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.transformer = Transformer( dim = dim, depth = depth, heads = heads, attn_dropout = attn_dropout, ff_dropout = ff_dropout, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs ) self.codebook_size = codebook_size self.num_coarse_quantizers = num_coarse_quantizers Loading Loading @@ -765,7 +807,8 @@ class SemanticTransformerWrapper(nn.Module): transformer: SemanticTransformer, wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None, pad_id = -1, unique_consecutive = True unique_consecutive = True, mask_prob = 0.15 ): super().__init__() self.wav2vec = wav2vec Loading @@ -775,6 +818,7 @@ class SemanticTransformerWrapper(nn.Module): self.unique_consecutive = unique_consecutive self.pad_id = pad_id self.eos_id = transformer.eos_id self.mask_prob = mask_prob @property def device(self): Loading Loading @@ -888,10 +932,15 @@ class SemanticTransformerWrapper(nn.Module): if return_loss: input_ids = semantic_token_ids[:, :-1] self_attn_mask = None if self.mask_prob > 0.: self_attn_mask = generate_mask_with_prob(input_ids.shape, self.mask_prob, input_ids.device) logits = self.transformer( ids = input_ids, text = text, text_embeds = text_embeds, self_attn_mask = self_attn_mask, **kwargs ) Loading @@ -916,7 +965,8 @@ class CoarseTransformerWrapper(nn.Module): wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None, pad_id = -1, unique_consecutive = True, semantic_cross_entropy_loss_weight = 1. semantic_cross_entropy_loss_weight = 1., mask_prob = 0.15 ): super().__init__() self.soundstream = soundstream Loading @@ -932,6 +982,8 @@ class CoarseTransformerWrapper(nn.Module): self.semantic_eos_id = transformer.semantic_eos_id self.coarse_eos_id = transformer.coarse_eos_id self.mask_prob = mask_prob @property def device(self): return next(self.parameters()).device Loading Loading @@ -1064,6 +1116,13 @@ class CoarseTransformerWrapper(nn.Module): **kwargs ) # forgetful causal mask - structured dropout if self.mask_prob > 0: self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device) # whether to early return the logits if not return_loss: return semantic_logits, coarse_logits Loading Loading @@ -1100,7 +1159,8 @@ class FineTransformerWrapper(nn.Module): transformer: FineTransformer, soundstream: Optional[SoundStream] = None, coarse_cross_entropy_loss_weight = 1., pad_id = -1 pad_id = -1, mask_prob = 0.15 ): super().__init__() self.soundstream = soundstream Loading @@ -1115,6 +1175,8 @@ class FineTransformerWrapper(nn.Module): self.pad_id = pad_id self.coarse_cross_entropy_loss_weight = coarse_cross_entropy_loss_weight self.mask_prob = mask_prob @property def device(self): return next(self.parameters()).device Loading Loading @@ -1256,6 +1318,13 @@ class FineTransformerWrapper(nn.Module): **kwargs ) # forgetful causal mask - structured dropout if self.mask_prob > 0: self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device) # early return the logits if not return_loss: return coarse_logits, fine_logits Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.66', version = '0.0.67', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
README.md +10 −0 Original line number Diff line number Diff line Loading @@ -275,3 +275,13 @@ generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the d primaryClass = {cs.CV} } ``` ```bibtex @article{Liu2022FCMFC, title = {FCM: Forgetful Causal Masking Makes Causal Language Models Better Zero-Shot Learners}, author = {Hao Liu and Xinyang Geng and Lisa Lee and Igor Mordatch and Sergey Levine and Sharan Narang and P. Abbeel}, journal = {ArXiv}, year = {2022}, volume = {abs/2210.13432} } ```
audiolm_pytorch/audiolm_pytorch.py +92 −23 Original line number Diff line number Diff line Loading @@ -49,6 +49,17 @@ def eval_decorator(fn): return out return inner # tensor helpers def generate_mask_with_prob(shape, mask_prob, device): seq = shape[-1] rand = torch.randn(shape, device = device) rand[:, 0] = -torch.finfo(rand.dtype).max num_mask = min(int(seq * mask_prob), seq - 1) indices = rand.topk(num_mask, dim = -1).indices mask = ~torch.zeros(shape, device = device).scatter(1, indices, 1.).bool() return mask # attention related utils def grad_shrink(t, alpha = 0.1): Loading Loading @@ -291,8 +302,8 @@ class Transformer(nn.Module): heads, dim_context = None, cross_attend = False, attn_dropout = 0.1, ff_dropout = 0.1, attn_dropout = 0., ff_dropout = 0., grad_shrink_alpha = 0.1, **kwargs ): Loading Loading @@ -346,10 +357,10 @@ class SemanticTransformer(nn.Module): *, dim, depth, heads, num_semantic_tokens, attn_dropout = 0.1, ff_dropout = 0.1, heads = 8, attn_dropout = 0., ff_dropout = 0., t5_name = DEFAULT_T5_NAME, has_condition = False, cond_drop_prob = 0.5, Loading @@ -370,7 +381,8 @@ class SemanticTransformer(nn.Module): self.eos_id = num_semantic_tokens self.pad_id = pad_id self.transformer = Transformer(dim = dim, self.transformer = Transformer( dim = dim, depth = depth, heads = heads, attn_dropout = attn_dropout, Loading @@ -378,7 +390,9 @@ class SemanticTransformer(nn.Module): dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) **kwargs ) self.to_logits = nn.Linear(dim, num_semantic_tokens + 1) @property Loading Loading @@ -406,6 +420,7 @@ class SemanticTransformer(nn.Module): return_loss = False, text: Optional[List[str]] = None, text_embeds = None, self_attn_mask = None, cond_drop_prob = None, unique_consecutive = None ): Loading Loading @@ -437,7 +452,10 @@ class SemanticTransformer(nn.Module): tokens = torch.cat((start_tokens, tokens), dim = 1) tokens = self.transformer(tokens, context = text_embeds, context_mask = text_mask) if exists(self_attn_mask): self_attn_mask = F.pad(self_attn_mask, (1, 0), value = True) tokens = self.transformer(tokens, context = text_embeds, self_attn_mask = self_attn_mask, context_mask = text_mask) return self.to_logits(tokens) @beartype Loading @@ -449,8 +467,10 @@ class CoarseTransformer(nn.Module): num_coarse_quantizers, dim, depth, heads, num_semantic_tokens, heads = 8, attn_dropout = 0., ff_dropout = 0., t5_name = DEFAULT_T5_NAME, has_condition = False, cond_drop_prob = 0.5, Loading @@ -472,7 +492,17 @@ class CoarseTransformer(nn.Module): codebook_size_with_eos = codebook_size + 1 self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size_with_eos, dim) self.transformer = Transformer(dim = dim, depth = depth, heads = heads, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.transformer = Transformer( dim = dim, depth = depth, heads = heads, attn_dropout = attn_dropout, ff_dropout = ff_dropout, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs ) self.codebook_size = codebook_size self.num_coarse_quantizers = num_coarse_quantizers Loading Loading @@ -597,7 +627,9 @@ class FineTransformer(nn.Module): codebook_size, dim, depth, heads, heads = 8, attn_dropout = 0., ff_dropout = 0., t5_name = DEFAULT_T5_NAME, has_condition = False, cond_drop_prob = 0.5, Loading @@ -621,7 +653,17 @@ class FineTransformer(nn.Module): self.eos_id = codebook_size self.transformer = Transformer(dim = dim, depth = depth, heads = heads, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.transformer = Transformer( dim = dim, depth = depth, heads = heads, attn_dropout = attn_dropout, ff_dropout = ff_dropout, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs ) self.codebook_size = codebook_size self.num_coarse_quantizers = num_coarse_quantizers Loading Loading @@ -765,7 +807,8 @@ class SemanticTransformerWrapper(nn.Module): transformer: SemanticTransformer, wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None, pad_id = -1, unique_consecutive = True unique_consecutive = True, mask_prob = 0.15 ): super().__init__() self.wav2vec = wav2vec Loading @@ -775,6 +818,7 @@ class SemanticTransformerWrapper(nn.Module): self.unique_consecutive = unique_consecutive self.pad_id = pad_id self.eos_id = transformer.eos_id self.mask_prob = mask_prob @property def device(self): Loading Loading @@ -888,10 +932,15 @@ class SemanticTransformerWrapper(nn.Module): if return_loss: input_ids = semantic_token_ids[:, :-1] self_attn_mask = None if self.mask_prob > 0.: self_attn_mask = generate_mask_with_prob(input_ids.shape, self.mask_prob, input_ids.device) logits = self.transformer( ids = input_ids, text = text, text_embeds = text_embeds, self_attn_mask = self_attn_mask, **kwargs ) Loading @@ -916,7 +965,8 @@ class CoarseTransformerWrapper(nn.Module): wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None, pad_id = -1, unique_consecutive = True, semantic_cross_entropy_loss_weight = 1. semantic_cross_entropy_loss_weight = 1., mask_prob = 0.15 ): super().__init__() self.soundstream = soundstream Loading @@ -932,6 +982,8 @@ class CoarseTransformerWrapper(nn.Module): self.semantic_eos_id = transformer.semantic_eos_id self.coarse_eos_id = transformer.coarse_eos_id self.mask_prob = mask_prob @property def device(self): return next(self.parameters()).device Loading Loading @@ -1064,6 +1116,13 @@ class CoarseTransformerWrapper(nn.Module): **kwargs ) # forgetful causal mask - structured dropout if self.mask_prob > 0: self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device) # whether to early return the logits if not return_loss: return semantic_logits, coarse_logits Loading Loading @@ -1100,7 +1159,8 @@ class FineTransformerWrapper(nn.Module): transformer: FineTransformer, soundstream: Optional[SoundStream] = None, coarse_cross_entropy_loss_weight = 1., pad_id = -1 pad_id = -1, mask_prob = 0.15 ): super().__init__() self.soundstream = soundstream Loading @@ -1115,6 +1175,8 @@ class FineTransformerWrapper(nn.Module): self.pad_id = pad_id self.coarse_cross_entropy_loss_weight = coarse_cross_entropy_loss_weight self.mask_prob = mask_prob @property def device(self): return next(self.parameters()).device Loading Loading @@ -1256,6 +1318,13 @@ class FineTransformerWrapper(nn.Module): **kwargs ) # forgetful causal mask - structured dropout if self.mask_prob > 0: self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device) # early return the logits if not return_loss: return coarse_logits, fine_logits Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.66', version = '0.0.67', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading