Loading audiolm_pytorch/audiolm_pytorch.py +23 −7 Original line number Diff line number Diff line Loading @@ -683,7 +683,8 @@ class FineTransformer(nn.Module): self.embed_text = partial(t5_encode_text, name = t5_name) self.cond_drop_prob = cond_drop_prob self.start_token = nn.Parameter(torch.randn(dim)) self.coarse_start_token = nn.Parameter(torch.randn(dim)) self.fine_start_token = nn.Parameter(torch.randn(dim)) codebook_size_with_eos = codebook_size + 1 Loading Loading @@ -725,7 +726,8 @@ class FineTransformer(nn.Module): fine_token_ids, text = None, text_embeds = None, cond_drop_prob = None cond_drop_prob = None, self_attn_mask = None ): b, device = coarse_token_ids.shape[0], coarse_token_ids.device has_text = exists(text) or exists(text_embeds) Loading Loading @@ -760,13 +762,19 @@ class FineTransformer(nn.Module): coarse_tokens = self.coarse_embedding(coarse_token_ids) fine_tokens = self.fine_embedding(fine_token_ids) start_tokens = repeat(self.start_token, 'd -> b 1 d', b = b) coarse_start_tokens = repeat(self.coarse_start_token, 'd -> b 1 d', b = b) fine_start_tokens = repeat(self.fine_start_token, 'd -> b 1 d', b = b) tokens = torch.cat((start_tokens, coarse_tokens, fine_tokens), dim = 1) tokens = torch.cat(( coarse_start_tokens, coarse_tokens, fine_start_tokens, fine_tokens ), dim = 1) tokens = self.transformer(tokens, context = text_embeds, context_mask = text_mask) tokens = self.transformer(tokens, context = text_embeds, self_attn_mask =self_attn_mask, context_mask = text_mask) pred_coarse_tokens, pred_fine_tokens = tokens[:, :n], tokens[:, n:] pred_coarse_tokens, pred_fine_tokens = tokens[:, :n], tokens[:, (n + 1):] # get coarse logits Loading Loading @@ -916,7 +924,8 @@ class FineTransformerWrapper(nn.Module): *, transformer: FineTransformer, soundstream: Optional[SoundStream] = None, num_coarse_quantize = 3 num_coarse_quantize = 3, pad_id = -1 ): super().__init__() self.soundstream = soundstream Loading @@ -924,6 +933,7 @@ class FineTransformerWrapper(nn.Module): assert num_coarse_quantize > 0 self.num_coarse_quantize = num_coarse_quantize self.pad_id = pad_id def forward( self, Loading Loading @@ -955,9 +965,15 @@ class FineTransformerWrapper(nn.Module): coarse_labels, fine_labels = coarse_token_ids, fine_token_ids.clone() fine_token_ids = fine_token_ids[:, :-1] self_attn_mask = coarse_token_ids != self.pad_id fine_token_seq_len = fine_token_ids.shape[-1] self_attn_mask = F.pad(self_attn_mask, (1, fine_token_seq_len + 1), value = True) coarse_logits, fine_logits = self.transformer( coarse_token_ids = coarse_token_ids, fine_token_ids = fine_token_ids, self_attn_mask = self_attn_mask, **kwargs ) Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.34', version = '0.0.36', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/audiolm_pytorch.py +23 −7 Original line number Diff line number Diff line Loading @@ -683,7 +683,8 @@ class FineTransformer(nn.Module): self.embed_text = partial(t5_encode_text, name = t5_name) self.cond_drop_prob = cond_drop_prob self.start_token = nn.Parameter(torch.randn(dim)) self.coarse_start_token = nn.Parameter(torch.randn(dim)) self.fine_start_token = nn.Parameter(torch.randn(dim)) codebook_size_with_eos = codebook_size + 1 Loading Loading @@ -725,7 +726,8 @@ class FineTransformer(nn.Module): fine_token_ids, text = None, text_embeds = None, cond_drop_prob = None cond_drop_prob = None, self_attn_mask = None ): b, device = coarse_token_ids.shape[0], coarse_token_ids.device has_text = exists(text) or exists(text_embeds) Loading Loading @@ -760,13 +762,19 @@ class FineTransformer(nn.Module): coarse_tokens = self.coarse_embedding(coarse_token_ids) fine_tokens = self.fine_embedding(fine_token_ids) start_tokens = repeat(self.start_token, 'd -> b 1 d', b = b) coarse_start_tokens = repeat(self.coarse_start_token, 'd -> b 1 d', b = b) fine_start_tokens = repeat(self.fine_start_token, 'd -> b 1 d', b = b) tokens = torch.cat((start_tokens, coarse_tokens, fine_tokens), dim = 1) tokens = torch.cat(( coarse_start_tokens, coarse_tokens, fine_start_tokens, fine_tokens ), dim = 1) tokens = self.transformer(tokens, context = text_embeds, context_mask = text_mask) tokens = self.transformer(tokens, context = text_embeds, self_attn_mask =self_attn_mask, context_mask = text_mask) pred_coarse_tokens, pred_fine_tokens = tokens[:, :n], tokens[:, n:] pred_coarse_tokens, pred_fine_tokens = tokens[:, :n], tokens[:, (n + 1):] # get coarse logits Loading Loading @@ -916,7 +924,8 @@ class FineTransformerWrapper(nn.Module): *, transformer: FineTransformer, soundstream: Optional[SoundStream] = None, num_coarse_quantize = 3 num_coarse_quantize = 3, pad_id = -1 ): super().__init__() self.soundstream = soundstream Loading @@ -924,6 +933,7 @@ class FineTransformerWrapper(nn.Module): assert num_coarse_quantize > 0 self.num_coarse_quantize = num_coarse_quantize self.pad_id = pad_id def forward( self, Loading Loading @@ -955,9 +965,15 @@ class FineTransformerWrapper(nn.Module): coarse_labels, fine_labels = coarse_token_ids, fine_token_ids.clone() fine_token_ids = fine_token_ids[:, :-1] self_attn_mask = coarse_token_ids != self.pad_id fine_token_seq_len = fine_token_ids.shape[-1] self_attn_mask = F.pad(self_attn_mask, (1, fine_token_seq_len + 1), value = True) coarse_logits, fine_logits = self.transformer( coarse_token_ids = coarse_token_ids, fine_token_ids = fine_token_ids, self_attn_mask = self_attn_mask, **kwargs ) Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.34', version = '0.0.36', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading