Loading audiolm_pytorch/audiolm_pytorch.py +17 −3 Original line number Diff line number Diff line Loading @@ -353,6 +353,8 @@ class SemanticTransformer(nn.Module): self, *, max_length, text = None, text_embeds = None, prime_wave = None, prime_ids = None, batch_size = 1, Loading @@ -377,6 +379,15 @@ class SemanticTransformer(nn.Module): if self.unique_consecutive: ids = batch_unique_consecutive(ids, pad_value = self.pad_id) # derive text embeddings if needed has_text = exists(text) or exists(text_embeds) assert not (self.has_condition ^ has_text) if not exists(text_embeds) and exists(text): with torch.no_grad(): text_embeds = self.embed_text(text, output_device = device) # start length and get running id output start_length = ids.shape[-1] Loading @@ -388,6 +399,7 @@ class SemanticTransformer(nn.Module): logits = self.forward_with_cond_scale( ids = output, text_embeds = text_embeds, **kwargs ) Loading @@ -401,7 +413,7 @@ class SemanticTransformer(nn.Module): if all_rows_have_eos_id(output, self.eos_id): break output = mask_out_after_eos_id(output, self.eos_id) output = mask_out_after_eos_id(output, self.pad_id) return output def forward_with_cond_scale( Loading Loading @@ -559,10 +571,12 @@ class CoarseTransformer(nn.Module): has_text = exists(text) or exists(text_embeds) assert not (self.has_condition ^ has_text) text_mask = None if not exists(text_embeds) and exists(text): with torch.no_grad(): text_embeds = self.embed_text(text, output_device = device) text_mask = None if exists(text_embeds): text_mask = torch.any(text_embeds != 0, dim = -1) cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob) Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.29', version = '0.0.30', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/audiolm_pytorch.py +17 −3 Original line number Diff line number Diff line Loading @@ -353,6 +353,8 @@ class SemanticTransformer(nn.Module): self, *, max_length, text = None, text_embeds = None, prime_wave = None, prime_ids = None, batch_size = 1, Loading @@ -377,6 +379,15 @@ class SemanticTransformer(nn.Module): if self.unique_consecutive: ids = batch_unique_consecutive(ids, pad_value = self.pad_id) # derive text embeddings if needed has_text = exists(text) or exists(text_embeds) assert not (self.has_condition ^ has_text) if not exists(text_embeds) and exists(text): with torch.no_grad(): text_embeds = self.embed_text(text, output_device = device) # start length and get running id output start_length = ids.shape[-1] Loading @@ -388,6 +399,7 @@ class SemanticTransformer(nn.Module): logits = self.forward_with_cond_scale( ids = output, text_embeds = text_embeds, **kwargs ) Loading @@ -401,7 +413,7 @@ class SemanticTransformer(nn.Module): if all_rows_have_eos_id(output, self.eos_id): break output = mask_out_after_eos_id(output, self.eos_id) output = mask_out_after_eos_id(output, self.pad_id) return output def forward_with_cond_scale( Loading Loading @@ -559,10 +571,12 @@ class CoarseTransformer(nn.Module): has_text = exists(text) or exists(text_embeds) assert not (self.has_condition ^ has_text) text_mask = None if not exists(text_embeds) and exists(text): with torch.no_grad(): text_embeds = self.embed_text(text, output_device = device) text_mask = None if exists(text_embeds): text_mask = torch.any(text_embeds != 0, dim = -1) cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob) Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.29', version = '0.0.30', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading