semantic token ids will have variable lengths because of unique consecutive,...
semantic token ids will have variable lengths because of unique consecutive, so eos token must be manually selected and then used to predict the first coarse token, in the coarse transformer
rearrange(pred_semantic_eos_tokens,'b d -> b 1 d'),
pred_coarse_tokens),
dim=1)
# semantic logits
@@ -789,74 +800,6 @@ class FineTransformer(nn.Module):
# training wrappers
classFineTransformerWrapper(nn.Module):
def__init__(
self,
*,
transformer:FineTransformer,
soundstream:Optional[SoundStream]=None,
num_coarse_quantize=3
):
super().__init__()
self.soundstream=soundstream
self.transformer=transformer
assertnum_coarse_quantize>0
self.num_coarse_quantize=num_coarse_quantize
defforward(
self,
*,
raw_wave=None,
coarse_token_ids=None,
fine_token_ids=None,
return_loss=False,
**kwargs
):
assertexists(raw_wave)^(exists(coarse_token_ids)andexists(fine_token_ids)),'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'
ifexists(raw_wave):
assertexists(self.soundstream),'SoundStream must be provided if given raw wave for training'
assertexists(raw_wave)^(exists(coarse_token_ids)andexists(fine_token_ids)),'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'
ifexists(raw_wave):
assertexists(self.soundstream),'SoundStream must be provided if given raw wave for training'