-<ahref="https://stability.ai/">Stability.ai</a> for the generous sponsorship to work and open source cutting edge artificial intelligence research
-<ahref="https://huggingface.co/">🤗 Huggingface</a> for their amazing accelerate library
-<ahref="https://huggingface.co/">🤗 Huggingface</a> for their amazing accelerate and transformers libraries
-<ahref="https://ai.facebook.com/">MetaAI</a> for <ahref="https://github.com/facebookresearch/fairseq">Fairseq</a> and the liberal license
@@ -58,6 +58,7 @@ loss.backward()
- [x] complete CoarseTransformer
- [x] use fairseq vq-wav2vec for embeddings
- [x] add conditioning
- [ ] incorporate ability to use hubert intermediate features as semantic tokens, recommended by <ahref="https://github.com/lucidrains/audiolm-pytorch/discussions/13">eonglints</a>
- [ ] complete full training code for soundstream, taking care of discriminator training
@@ -69,7 +70,8 @@ loss.backward()
- [ ] offer option to weight tie coarse, fine, and semantic embeddings across the 3 hierarchical transformers
- [ ] DRY a little at the end
- [ ] figure out how to suppress logging in fairseq
- [ ] test with speech synthesis for starters, add conditioning + classifier free guidance as well
@@ -794,7 +895,8 @@ class FineTransformerWrapper(nn.Module):
raw_wave=None,
coarse_token_ids=None,
fine_token_ids=None,
return_loss=False
return_loss=False,
**kwargs
):
assertexists(raw_wave)^(exists(coarse_token_ids)andexists(fine_token_ids)),'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'
@@ -815,7 +917,8 @@ class FineTransformerWrapper(nn.Module):
coarse_logits,fine_logits=self.transformer(
coarse_token_ids=coarse_token_ids,
fine_token_ids=fine_token_ids
fine_token_ids=fine_token_ids,
**kwargs
)
ifnotreturn_loss:
@@ -859,7 +962,8 @@ class CoarseTransformerWrapper(nn.Module):
semantic_token_ids=None,
raw_wave=None,
coarse_token_ids=None,
return_loss=False
return_loss=False,
**kwargs
):
assertexists(raw_wave)orexists(semantic_token_ids),'either raw waveform (raw_wave) is given or semantic token ids are given (semantic_token_ids)'
assertexists(raw_wave)orexists(coarse_token_ids),'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'
@@ -886,7 +990,8 @@ class CoarseTransformerWrapper(nn.Module):