Commit 21a07ab9 authored by Phil Wang

get the default parameters for soundstream, as well as AudioLMSoundStream and MusicLMSoundStream correct, addressing https://github.com/lucidrains/audiolm-pytorch/issues/110
parent 9b8e702f
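
As a quick check of what the corrected `SoundStream` defaults in this commit imply, the sketch below computes the total downsampling factor and the resulting codec frame rate for the old and new values. It assumes, as in the SoundStream design, that the encoder downsamples by the product of its strides; `frames_per_second` is a hypothetical helper and is not part of `audiolm_pytorch`.

```python
from math import prod


def frames_per_second(strides, target_sample_hz):
    """Return (hop, rate): samples per codec frame and frames per second.

    Assumption: total downsampling equals the product of the strides.
    """
    hop = prod(strides)
    return hop, target_sample_hz / hop


# defaults before this commit
old_hop, old_rate = frames_per_second((3, 4, 5, 8), 24000)

# defaults after this commit
new_hop, new_rate = frames_per_second((2, 4, 5, 8), 16000)

print(old_hop, old_rate)  # 480 50.0
print(new_hop, new_rate)  # 320 50.0
```

Under that assumption both configurations produce 50 frames per second; what changes is the sample rate and the number of samples each frame covers.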
+10 −0
@@ -69,6 +69,16 @@ audio = torch.randn(10080).cuda()
recons = soundstream(audio, return_recons_only = True) # (1, 10080) - 1 channel
```

You can also use soundstreams that are specific to `AudioLM` and `MusicLM` by importing `AudioLMSoundStream` and `MusicLMSoundStream` respectively

```python
from audiolm_pytorch import AudioLMSoundStream, MusicLMSoundStream

soundstream = AudioLMSoundStream(...) # say you want the hyperparameters as in Audio LM paper

# rest is the same as above
```

Then three separate transformers (`SemanticTransformer`, `CoarseTransformer`, `FineTransformer`) need to be trained


+5 −3
@@ -348,7 +348,7 @@ class SoundStream(nn.Module):
        self,
        *,
        channels = 32,
-        strides = (3, 4, 5, 8),
+        strides = (2, 4, 5, 8),
        channel_mults = (2, 4, 8, 16),
        codebook_dim = 512,
        codebook_size = 1024,
@@ -368,7 +368,7 @@ class SoundStream(nn.Module):
        adversarial_loss_weight = 1.,
        feature_loss_weight = 100,
        quantize_dropout_cutoff_index = 1,
-        target_sample_hz = 24000,
+        target_sample_hz = 16000,
        use_local_attn = True,
        use_mhesa = True,
        mhesa_heads = 4,
@@ -716,12 +716,13 @@ class AudioLMSoundStream(SoundStream):
        self,
        strides = (2, 4, 5, 8),
        target_sample_hz = 16000,
-        rq_num_quantizers = 8,
+        rq_num_quantizers = 12,
        **kwargs
    ):
        super().__init__(
            strides = strides,
            target_sample_hz = target_sample_hz,
+            rq_num_quantizers = rq_num_quantizers,
            **kwargs
        )

@@ -736,5 +737,6 @@ class MusicLMSoundStream(SoundStream):
        super().__init__(
            strides = strides,
            target_sample_hz = target_sample_hz,
+            rq_num_quantizers = rq_num_quantizers,
            **kwargs
        )
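
Judging by the hunk line counts, the last two hunks add `rq_num_quantizers = rq_num_quantizers` to the `super().__init__` calls. Before that, the subclass appears to have accepted the argument without passing it on, so the parent's default would have applied regardless. The sketch below reproduces that pattern with hypothetical stand-in classes (`Base`, `BrokenChild`, `FixedChild`). The parent default of 8 is an assumption for illustration and is not taken from the diff.

```python
class Base:
    # hypothetical stand-in for the parent class; the default of 8 is assumed
    def __init__(self, *, strides=(2, 4, 5, 8), rq_num_quantizers=8):
        self.strides = strides
        self.rq_num_quantizers = rq_num_quantizers


class BrokenChild(Base):
    def __init__(self, strides=(2, 4, 5, 8), rq_num_quantizers=12, **kwargs):
        # rq_num_quantizers is captured by the signature but never forwarded,
        # so the parent's default is used no matter what the caller passes
        super().__init__(strides=strides, **kwargs)


class FixedChild(Base):
    def __init__(self, strides=(2, 4, 5, 8), rq_num_quantizers=12, **kwargs):
        # forwarding the argument makes the subclass default take effect
        super().__init__(
            strides=strides,
            rq_num_quantizers=rq_num_quantizers,
            **kwargs
        )


broken = BrokenChild()
fixed = FixedChild()
overridden = BrokenChild(rq_num_quantizers=4)

print(broken.rq_num_quantizers)      # 8
print(fixed.rq_num_quantizers)       # 12
print(overridden.rq_num_quantizers)  # 8
```

Because the named parameter absorbs the keyword before it can reach `**kwargs`, the broken variant raises no error; the value is silently dropped.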
+1 −1
-__version__ = '0.15.9'
+__version__ = '0.16.1'