Commit 02902731 authored by Phil Wang's avatar Phil Wang
Browse files

make sure unconditional synthesis can still work, add ability to resample...

make sure unconditional synthesis can still work, add ability to resample input wave on the fly given input sampling frequencies is supplied
parent 2725ae89
Loading
Loading
Loading
Loading
+26 −14
Original line number Diff line number Diff line
@@ -18,6 +18,8 @@ from audiolm_pytorch.hubert_kmeans import HubertWithKmeans

from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

from torchaudio.functional import resample

# helper functions

def exists(val):
@@ -295,9 +297,12 @@ class SoundStream(nn.Module):
        adversarial_loss_weight = 1.,
        feature_loss_weight = 100,
        quantize_dropout = True,
        quantize_dropout_cutoff_index = 0
        quantize_dropout_cutoff_index = 0,
        target_sample_khz = 24000
    ):
        super().__init__()
        self.target_sample_khz = target_sample_khz # for resampling on the fly

        self.single_channel = input_channels == 1
        self.strides = strides

@@ -363,8 +368,12 @@ class SoundStream(nn.Module):
        return_encoded = False,
        return_discr_loss = False,
        return_discr_losses_separately = False,
        return_recons_only = False
        return_recons_only = False,
        input_sample_khz = None
    ):
        if exists(input_sample_khz):
            x = resample(x, input_sample_khz, self.target_sample_khz)

        if x.ndim == 2:
            x = rearrange(x, 'b n -> b 1 n')

@@ -699,7 +708,7 @@ class SemanticTransformer(nn.Module):
        ids = None,
        return_loss = False,
        text = None,
        text_embed = None,
        text_embeds = None,
        cond_drop_prob = None
    ):
        device = next(self.parameters()).device
@@ -717,17 +726,18 @@ class SemanticTransformer(nn.Module):
        if self.unique_consecutive:
            ids = batch_unique_consecutive(ids, pad_value = self.pad_id)

        has_text = exists(text) or exists(text_embed)
        has_text = exists(text) or exists(text_embeds)
        assert not (self.has_condition ^ has_text)

        if not exists(text_embed):
        text_mask = None
        if not exists(text_embeds) and exists(text):
            with torch.no_grad():
                text_embeds = self.embed_text(text, output_device = device)
                text_mask = torch.any(text_embeds != 0, dim = -1)

        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

        if cond_drop_prob > 0:
        if exists(text_mask) and cond_drop_prob > 0:
            keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device)
            text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask

@@ -798,22 +808,23 @@ class CoarseTransformer(nn.Module):
        coarse_token_ids,
        self_attn_mask = None,
        text = None,
        text_embed = None,
        text_embeds = None,
        cond_drop_prob = None
    ):
        b, device = semantic_token_ids.shape[0], semantic_token_ids.device

        has_text = exists(text) or exists(text_embed)
        has_text = exists(text) or exists(text_embeds)
        assert not (self.has_condition ^ has_text)

        if not exists(text_embed):
        text_mask = None
        if not exists(text_embeds) and exists(text):
            with torch.no_grad():
                text_embeds = self.embed_text(text, output_device = device)
                text_mask = torch.any(text_embeds != 0, dim = -1)

        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

        if cond_drop_prob > 0:
        if exists(text_mask) and cond_drop_prob > 0:
            keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device)
            text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask

@@ -906,21 +917,22 @@ class FineTransformer(nn.Module):
        coarse_token_ids,
        fine_token_ids,
        text = None,
        text_embed = None,
        text_embeds = None,
        cond_drop_prob = None
    ):
        b, device = coarse_token_ids.shape[0], coarse_token_ids.device
        has_text = exists(text) or exists(text_embed)
        has_text = exists(text) or exists(text_embeds)
        assert not (self.has_condition ^ has_text)

        if not exists(text_embed):
        text_mask = None
        if not exists(text_embeds) and exists(text):
            with torch.no_grad():
                text_embeds = self.embed_text(text, output_device = device)
                text_mask = torch.any(text_embeds != 0, dim = -1)

        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

        if cond_drop_prob > 0:
        if exists(text_mask) and cond_drop_prob > 0:
            keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device)
            text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask

+18 −2
Original line number Diff line number Diff line
@@ -7,13 +7,21 @@ from einops import rearrange, pack, unpack
import joblib
import fairseq

from torchaudio.functional import resample

def exists(val):
    return val is not None

class HubertWithKmeans(nn.Module):
    def __init__(
        self,
        checkpoint_path,
        kmeans_path
        kmeans_path,
        target_sample_khz = 50000
    ):
        super().__init__()
        self.target_sample_khz = target_sample_khz

        model_path = Path(checkpoint_path)
        kmeans_path = Path(kmeans_path)

@@ -39,9 +47,17 @@ class HubertWithKmeans(nn.Module):
        return self.kmeans.n_clusters

    @torch.no_grad()
    def forward(self, wav_input, flatten = True):
    def forward(
        self,
        wav_input,
        flatten = True,
        input_sample_khz = None
    ):
        device = wav_input.device

        if exists(input_sample_khz):
            wav_input = resample(wav_input, input_sample_khz, self.target_sample_khz)

        embed = self.model(wav_input, features_only = True)
        embed, packed_shape = pack([embed['x']], '* d')

+18 −2
Original line number Diff line number Diff line
@@ -6,12 +6,20 @@ from einops import rearrange

import fairseq

from torchaudio.functional import resample

def exists(val):
    return val is not None

class FairseqVQWav2Vec(nn.Module):
    def __init__(
        self,
        checkpoint_path
        checkpoint_path,
        target_sample_khz = 24000
    ):
        super().__init__()
        self.target_sample_khz = target_sample_khz

        path = Path(checkpoint_path)
        assert path.exists(), f'path {checkpoint_path} does not exist'

@@ -31,7 +39,15 @@ class FairseqVQWav2Vec(nn.Module):
        return self.model.vector_quantizer.embedding.shape[0]

    @torch.no_grad()
    def forward(self, wav_input, flatten = True):
    def forward(
        self,
        wav_input,
        flatten = True,
        input_sample_khz = None
    ):
        if exists(input_sample_khz):
            wav_input = resample(wav_input, input_sample_khz, self.target_sample_khz)

        embed = self.model.feature_extractor(wav_input)
        _, codebook_indices = self.model.vector_quantizer.forward_idx(embed)

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.20',
  version = '0.0.21',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',