Commit 018a37c9 authored by zhvng's avatar zhvng
Browse files

small bug fixes

parent 09f453fc
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -1016,6 +1016,7 @@ class SemanticTransformerWrapper(nn.Module):
            logits = self.transformer.forward_with_cond_scale(
                ids = sample_semantic_ids,
                text_embeds = text_embeds,
                cond_scale = cond_scale,
                **kwargs
            )

@@ -1597,7 +1598,7 @@ class AudioLM(nn.Module):

        if self.needs_text:
            if exists(text):
                text_embeds = self.semantic.embed_text(texts)
                text_embeds = self.semantic.embed_text(text)

        if exists(prime_wave):
            prime_wave = prime_wave.to(self.device)