Commit a17b189a authored by Phil Wang's avatar Phil Wang
Browse files

be able to reconstruct the coarse wav from coarse transformer to soundstream decoder

parent 684da45c
Loading
Loading
Loading
Loading
+9 −2
Original line number Diff line number Diff line
@@ -862,7 +862,8 @@ class CoarseTransformerWrapper(nn.Module):
        cond_scale = 3.,
        filter_thres = 0.9,
        temperature = 1.,
        reshape_output = True
        reshape_output = True,
        reconstruct_wave = False
    ):
        batch, device = semantic_token_ids.shape[0], self.device

@@ -897,9 +898,15 @@ class CoarseTransformerWrapper(nn.Module):

        output = mask_out_after_eos_id(output, self.eos_id, include_eos = False)

        if reshape_output:
        if reshape_output or reconstruct_wave:
            output = rearrange(output, 'b (n q) -> b n q', q = self.num_coarse_quantizers)

        if reconstruct_wave:
            assert exists(self.soundstream)
            wav = self.soundstream.decode_from_codebook_indices(output)
            wav = rearrange(wav, 'b 1 n -> b n')
            return wav

        return output

    def forward(
+7 −1
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ from functools import partial
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
from einops import rearrange, reduce

from vector_quantize_pytorch import ResidualVQ

@@ -310,6 +310,12 @@ class SoundStream(nn.Module):
        self.adversarial_loss_weight = adversarial_loss_weight
        self.feature_loss_weight = feature_loss_weight

    def decode_from_codebook_indices(self, quantized_indices):
        codes = self.rq.get_codes_from_indices(quantized_indices)
        x = reduce(codes, 'q ... -> ...', 'sum')
        x = rearrange(x, 'b n c -> b c n')
        return self.decoder(x)

    def load(self, path):
        path = Path(path)
        assert path.exists()
+2 −2
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.40',
  version = '0.0.41',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',
@@ -27,7 +27,7 @@ setup(
    'torchaudio',
    'transformers',
    'typeguard',
    'vector-quantize-pytorch>=0.10.10'
    'vector-quantize-pytorch>=0.10.11'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',