@@ -924,6 +915,8 @@ class SemanticTransformerWrapper(nn.Module):
super().__init__()
self.wav2vec=wav2vec
self.transformer=transformer
self.audio_conditioner=audio_conditioner
assertnotexists(self.wav2vec)orself.wav2vec.codebook_size==transformer.num_semantic_tokens,f'num_semantic_tokens on SemanticTransformer must be set to {self.wav2vec.codebook_size}'
self.unique_consecutive=unique_consecutive
@@ -969,6 +962,12 @@ class SemanticTransformerWrapper(nn.Module):
@@ -1391,6 +1408,8 @@ class FineTransformerWrapper(nn.Module):
self,
*,
raw_wave=None,
text=None,
text_embeds=None,
token_ids=None,
coarse_token_ids=None,
fine_token_ids=None,
@@ -1399,6 +1418,11 @@ class FineTransformerWrapper(nn.Module):
):
assertexists(raw_wave)^(exists(token_ids)^(exists(coarse_token_ids)andexists(fine_token_ids))),'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'
ifexists(self.audio_conditioner):
assertexists(raw_wave)
assertnotexists(text)andnotexists(text_embeds)
text_embeds=self.audio_conditioner(raw_wave)# technically audio embeds, but shared text-audio joint embedding space for mulan
ifexists(raw_wave):
assertexists(self.soundstream),'SoundStream must be provided if given raw wave for training'
@@ -1432,6 +1456,8 @@ class FineTransformerWrapper(nn.Module):