@@ -744,6 +798,68 @@ class FineTransformerWrapper(nn.Module):
return (coarse_loss+fine_loss)*0.5
classCoarseTransformerWrapper(nn.Module):
def __init__(
    self,
    *,
    soundstream: Optional[SoundStream],
    transformer: FineTransformer,
    num_coarse_quantize=3
):
    """Set up the wrapper around a transformer and an optional SoundStream.

    All arguments are keyword-only.

    Args:
        soundstream: SoundStream instance, or ``None``. ``forward`` asserts
            that it is present when a raw waveform is passed in.
        transformer: The transformer module this wrapper delegates to.
        num_coarse_quantize: Must be greater than 0; stored unchanged on the
            instance. Presumably the number of coarse quantizer levels --
            TODO confirm against the rest of ``forward``.

    Raises:
        AssertionError: If ``num_coarse_quantize`` is not greater than 0.
    """
    # nn.Module bookkeeping must exist before submodules are assigned.
    super().__init__()

    self.soundstream = soundstream

    # NOTE(review): annotated as FineTransformer even though this is the
    # coarse wrapper -- looks like a copy-paste leftover; confirm the
    # intended type before tightening the annotation.
    self.transformer = transformer

    assert num_coarse_quantize > 0
    self.num_coarse_quantize = num_coarse_quantize
defforward(
self,
*,
semantic_token_ids,
raw_wave=None,
coarse_token_ids=None,
return_loss=False
):
assertexists(raw_wave)^exists(coarse_token_ids),'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'
ifexists(raw_wave):
assertexists(self.soundstream),'SoundStream must be provided if given raw wave for training'