Loading audiolm_pytorch/audiolm_pytorch.py +12 −14 Original line number Diff line number Diff line Loading @@ -544,7 +544,8 @@ class CoarseTransformer(nn.Module): self.embed_text = partial(t5_encode_text, name = t5_name) self.cond_drop_prob = cond_drop_prob self.start_token = nn.Parameter(torch.randn(dim)) self.semantic_start_token = nn.Parameter(torch.randn(dim)) self.coarse_start_token = nn.Parameter(torch.randn(dim)) self.semantic_eos_id = num_semantic_tokens self.semantic_embedding = nn.Embedding(num_semantic_tokens + 1, dim) Loading Loading @@ -621,24 +622,20 @@ class CoarseTransformer(nn.Module): semantic_seq_len = semantic_tokens.shape[1] start_tokens = repeat(self.start_token, 'd -> b 1 d', b = b) semantic_start_tokens = repeat(self.semantic_start_token, 'd -> b 1 d', b = b) coarse_start_tokens = repeat(self.coarse_start_token, 'd -> b 1 d', b = b) tokens = torch.cat((start_tokens, semantic_tokens, coarse_tokens), dim = 1) tokens = torch.cat(( semantic_start_tokens, semantic_tokens, coarse_start_tokens, coarse_tokens ), dim = 1) tokens = self.transformer(tokens, context = text_embeds, self_attn_mask = self_attn_mask, context_mask = text_mask) pred_semantic_tokens, pred_coarse_tokens = tokens[:, :semantic_seq_len], tokens[:, (semantic_seq_len + 1):] # get the eos token from predicted semantic tokens, and use that to predict the first coarse token semantic_eos = semantic_token_ids == self.semantic_eos_id pred_semantic_eos_tokens = tokens[:, 1:(semantic_seq_len + 1)][semantic_eos] pred_coarse_tokens = torch.cat(( rearrange(pred_semantic_eos_tokens, 'b d -> b 1 d'), pred_coarse_tokens), dim = 1) # semantic logits semantic_logits = self.to_semantic_logits(pred_semantic_tokens) Loading Loading @@ -880,7 +877,8 @@ class CoarseTransformerWrapper(nn.Module): if self.unique_consecutive: self_attn_mask = semantic_token_ids != self.pad_id semantic_token_ids = semantic_token_ids.masked_fill(~self_attn_mask, 0) self_attn_mask = F.pad(self_attn_mask, (1, coarse_token_ids.shape[-1]), value = True) coarse_token_len = coarse_token_ids.shape[-1] self_attn_mask = F.pad(self_attn_mask, (1, coarse_token_len + 1), value = True) semantic_logits, coarse_logits = self.transformer( semantic_token_ids = semantic_token_ids, Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.33', version = '0.0.34', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/audiolm_pytorch.py +12 −14 Original line number Diff line number Diff line Loading @@ -544,7 +544,8 @@ class CoarseTransformer(nn.Module): self.embed_text = partial(t5_encode_text, name = t5_name) self.cond_drop_prob = cond_drop_prob self.start_token = nn.Parameter(torch.randn(dim)) self.semantic_start_token = nn.Parameter(torch.randn(dim)) self.coarse_start_token = nn.Parameter(torch.randn(dim)) self.semantic_eos_id = num_semantic_tokens self.semantic_embedding = nn.Embedding(num_semantic_tokens + 1, dim) Loading Loading @@ -621,24 +622,20 @@ class CoarseTransformer(nn.Module): semantic_seq_len = semantic_tokens.shape[1] start_tokens = repeat(self.start_token, 'd -> b 1 d', b = b) semantic_start_tokens = repeat(self.semantic_start_token, 'd -> b 1 d', b = b) coarse_start_tokens = repeat(self.coarse_start_token, 'd -> b 1 d', b = b) tokens = torch.cat((start_tokens, semantic_tokens, coarse_tokens), dim = 1) tokens = torch.cat(( semantic_start_tokens, semantic_tokens, coarse_start_tokens, coarse_tokens ), dim = 1) tokens = self.transformer(tokens, context = text_embeds, self_attn_mask = self_attn_mask, context_mask = text_mask) pred_semantic_tokens, pred_coarse_tokens = tokens[:, :semantic_seq_len], tokens[:, (semantic_seq_len + 1):] # get the eos token from predicted semantic tokens, and use that to predict the first coarse token semantic_eos = semantic_token_ids == self.semantic_eos_id pred_semantic_eos_tokens = tokens[:, 1:(semantic_seq_len + 1)][semantic_eos] pred_coarse_tokens = torch.cat(( rearrange(pred_semantic_eos_tokens, 'b d -> b 1 d'), pred_coarse_tokens), dim = 1) # semantic logits semantic_logits = self.to_semantic_logits(pred_semantic_tokens) Loading Loading @@ -880,7 +877,8 @@ class CoarseTransformerWrapper(nn.Module): if self.unique_consecutive: self_attn_mask = semantic_token_ids != self.pad_id semantic_token_ids = semantic_token_ids.masked_fill(~self_attn_mask, 0) self_attn_mask = F.pad(self_attn_mask, (1, coarse_token_ids.shape[-1]), value = True) coarse_token_len = coarse_token_ids.shape[-1] self_attn_mask = F.pad(self_attn_mask, (1, coarse_token_len + 1), value = True) semantic_logits, coarse_logits = self.transformer( semantic_token_ids = semantic_token_ids, Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.33', version = '0.0.34', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading