Loading audiolm_pytorch/audiolm_pytorch.py +2 −1 Original line number Diff line number Diff line Loading @@ -1016,6 +1016,7 @@ class SemanticTransformerWrapper(nn.Module): logits = self.transformer.forward_with_cond_scale( ids = sample_semantic_ids, text_embeds = text_embeds, cond_scale = cond_scale, **kwargs ) Loading Loading @@ -1600,7 +1601,7 @@ class AudioLM(nn.Module): if self.needs_text: if exists(text): text_embeds = self.semantic.embed_text(texts) text_embeds = self.semantic.embed_text(text) if exists(prime_wave): prime_wave = prime_wave.to(self.device) Loading Loading
audiolm_pytorch/audiolm_pytorch.py +2 −1 Original line number Diff line number Diff line Loading @@ -1016,6 +1016,7 @@ class SemanticTransformerWrapper(nn.Module): logits = self.transformer.forward_with_cond_scale( ids = sample_semantic_ids, text_embeds = text_embeds, cond_scale = cond_scale, **kwargs ) Loading Loading @@ -1600,7 +1601,7 @@ class AudioLM(nn.Module): if self.needs_text: if exists(text): text_embeds = self.semantic.embed_text(texts) text_embeds = self.semantic.embed_text(text) if exists(prime_wave): prime_wave = prime_wave.to(self.device) Loading