@@ -954,13 +954,17 @@ class CoarseTransformerWrapper(nn.Module):
*,
semantic_token_ids=None,
raw_wave=None,
raw_wave_for_soundstream=None,
coarse_token_ids=None,
return_loss=False,
**kwargs
):
assertexists(raw_wave)orexists(semantic_token_ids),'either raw waveform (raw_wave) is given or semantic token ids are given (semantic_token_ids)'
assertexists(raw_wave)orexists(coarse_token_ids),'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'
assertexists(raw_wave_for_soundstream)orexists(coarse_token_ids),'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'