Commit 71fe13fb authored by Phil Wang
Browse files

add ability to train with grouped residual vq, from hifi-codec paper

parent 5183576e
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -484,3 +484,11 @@ $ accelerate launch train.py
    pages   = {7132-7141}
}
```

```bibtex
@inproceedings{Yang2023HiFiCodecGV,
    title   = {HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec},
    author  = {Dongchao Yang and Songxiang Liu and Rongjie Huang and Jinchuan Tian and Chao Weng and Yuexian Zou},
    year    = {2023}
}
```
+4 −4
Original line number Diff line number Diff line
@@ -1293,7 +1293,7 @@ class CoarseTransformerWrapper(nn.Module):

        self.semantic_cross_entropy_loss_weight = semantic_cross_entropy_loss_weight

        self.num_coarse_quantizers = transformer.num_coarse_quantizers
        self.num_coarse_quantizers = transformer.num_coarse_quantizers * codec.rq_groups
        self.semantic_eos_id = transformer.semantic_eos_id
        self.coarse_eos_id = transformer.coarse_eos_id

@@ -1506,11 +1506,11 @@ class FineTransformerWrapper(nn.Module):

        assert not (exists(audio_conditioner) and not transformer.has_condition), 'if conditioning on audio embeddings from mulan, transformer has_condition must be set to True'

        self.num_fine_quantizers = transformer.num_fine_quantizers
        self.num_coarse_quantizers = transformer.num_coarse_quantizers
        self.num_fine_quantizers = transformer.num_fine_quantizers * codec.rq_groups
        self.num_coarse_quantizers = transformer.num_coarse_quantizers * codec.rq_groups

        if exists(codec):
            assert (self.num_fine_quantizers + self.num_coarse_quantizers) == codec.num_quantizers, 'number of fine and coarse quantizers on fine transformer must add up to total number of quantizers on codec'
            assert (self.num_fine_quantizers + self.num_coarse_quantizers) == (codec.num_quantizers * codec.rq_groups), 'number of fine and coarse quantizers on fine transformer must add up to total number of quantizers on codec'

        self.eos_id = transformer.eos_id

+1 −0
Original line number Diff line number Diff line
@@ -41,6 +41,7 @@ class EncodecWrapper(nn.Module):
        assert self.target_sample_hz == 24000, "haven't done anything with non-24kHz yet"

        self.codebook_dim = 128
        self.rq_groups = 1
        self.num_quantizers = num_quantizers
        self.strides = strides # used in seq_len_multiple_of

+10 −3
Original line number Diff line number Diff line
@@ -17,7 +17,7 @@ from torchaudio.functional import resample

from einops import rearrange, reduce, pack, unpack

from vector_quantize_pytorch import ResidualVQ
from vector_quantize_pytorch import GroupedResidualVQ

from local_attention import LocalMHA
from local_attention.transformer import FeedForward, DynamicPositionBias
@@ -426,6 +426,7 @@ class SoundStream(nn.Module):
        rq_commitment_weight = 1.,
        rq_ema_decay = 0.95,
        rq_quantize_dropout_multiple_of = 1,
        rq_groups = 1,
        input_channels = 1,
        discr_multi_scales = (1, 0.5, 0.25),
        stft_normalized = False,
@@ -502,10 +503,13 @@ class SoundStream(nn.Module):

        self.codebook_dim = codebook_dim

        self.rq = ResidualVQ(
        self.rq_groups = rq_groups

        self.rq = GroupedResidualVQ(
            dim = codebook_dim,
            num_quantizers = rq_num_quantizers,
            codebook_size = codebook_size,
            groups = rq_groups,
            decay = rq_ema_decay,
            commitment_weight = rq_commitment_weight,
            quantize_dropout_multiple_of = rq_quantize_dropout_multiple_of,
@@ -592,8 +596,10 @@ class SoundStream(nn.Module):
        return pickle.loads(self._configs)

    def decode_from_codebook_indices(self, quantized_indices):
        """Turn codebook indices into codes, combine them, and pass the result to ``self.decode``.

        ``quantized_indices`` must have a last axis whose length is divisible by
        ``self.rq_groups``; the rearrange pattern below names its axes ``b n (g q)``
        (presumably batch, sequence, and groups * quantizers -- TODO confirm).

        Returns whatever ``self.decode`` returns for the combined codes.
        """
        # Split the flattened last axis into `g = self.rq_groups` groups and move the
        # group axis to the front, giving the 'g b n q' layout handed to `self.rq`.
        quantized_indices = rearrange(quantized_indices, 'b n (g q) -> g b n q', g = self.rq_groups)

        # Look up the codes for every index; the reduce pattern two lines below
        # treats the result as having axes 'g q b n d'.
        codes = self.rq.get_codes_from_indices(quantized_indices)
        # NOTE(review): this text comes from a diff view that carries no +/- markers.
        # The next two `reduce` lines look like the removed and the added version of
        # the same statement (sum over the leading axis only vs. the grouped form) --
        # confirm against the repository before treating both as live code.
        x = reduce(codes, 'q ... -> ...', 'sum')
        # Sum over the quantizer axis `q` and merge the group axis into the feature
        # axis, so the last axis becomes (g d).
        x = reduce(codes, 'g q b n d -> b n (g d)', 'sum')

        return self.decode(x)

@@ -716,6 +722,7 @@ class SoundStream(nn.Module):
        x, indices, commit_loss = self.rq(x)

        if return_encoded:
            indices = rearrange(indices, 'g b n q -> b n (g q)')
            return x, indices, commit_loss

        if exists(is_denoising):
+1 −1
Original line number Diff line number Diff line
__version__ = '0.29.0'
__version__ = '0.30.0'
Loading