@@ -612,6 +626,7 @@ class CoarseTransformer(nn.Module):
defforward(
self,
*,
semantic_token_ids,
coarse_token_ids,
):
@@ -825,10 +840,13 @@ class CoarseTransformerWrapper(nn.Module):
*,
transformer:FineTransformer,
soundstream:Optional[SoundStream]=None,
wav2vec:Optional[FairseqVQWav2Vec]=None,
num_coarse_quantize=3
):
super().__init__()
self.soundstream=soundstream
self.wav2vec=wav2vec
self.transformer=transformer
assertnum_coarse_quantize>0
@@ -837,14 +855,20 @@ class CoarseTransformerWrapper(nn.Module):
defforward(
self,
*,
semantic_token_ids,
semantic_token_ids=None,
raw_wave=None,
coarse_token_ids=None,
return_loss=False
):
assertexists(raw_wave)^exists(coarse_token_ids),'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'
assertexists(raw_wave)orexists(semantic_token_ids),'either raw waveform (raw_wave) is given or semantic token ids are given (semantic_token_ids)'
assertexists(raw_wave)orexists(coarse_token_ids),'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'