@@ -917,6 +936,8 @@ class SemanticTransformerWrapper(nn.Module):
self.transformer=transformer
self.audio_conditioner=audio_conditioner
assertnot(exists(audio_conditioner)andnottransformer.has_condition),'if conditioning on audio embeddings from mulan, transformer has_condition must be set to True'
assertnotexists(self.wav2vec)orself.wav2vec.codebook_size==transformer.num_semantic_tokens,f'num_semantic_tokens on SemanticTransformer must be set to {self.wav2vec.codebook_size}'
self.unique_consecutive=unique_consecutive
@@ -1088,9 +1109,12 @@ class CoarseTransformerWrapper(nn.Module):
super().__init__()
self.soundstream=soundstream
self.wav2vec=wav2vec
self.audio_conditioner=audio_conditioner
self.transformer=transformer
self.audio_conditioner=audio_conditioner
assertnot(exists(audio_conditioner)andnottransformer.has_condition),'if conditioning on audio embeddings from mulan, transformer has_condition must be set to True'
self.unique_consecutive=unique_consecutive
self.pad_id=pad_id
@@ -1293,9 +1317,12 @@ class FineTransformerWrapper(nn.Module):
):
super().__init__()
self.soundstream=soundstream
self.transformer=transformer
self.audio_conditioner=audio_conditioner
assertnot(exists(audio_conditioner)andnottransformer.has_condition),'if conditioning on audio embeddings from mulan, transformer has_condition must be set to True'