Commit 5b24b4f5 authored by Phil Wang
Browse files

Handle the case where any of the models requires the sequence length to be a multiple of some value

parent 02902731
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@ from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
from audiolm_pytorch.hubert_kmeans import HubertWithKmeans

from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
from audiolm_pytorch.utils import curtail_to_multiple

from torchaudio.functional import resample

@@ -374,6 +375,8 @@ class SoundStream(nn.Module):
        if exists(input_sample_khz):
            x = resample(x, input_sample_khz, self.target_sample_khz)

        x = curtail_to_multiple(x, self.seq_len_multiple_of)

        if x.ndim == 2:
            x = rearrange(x, 'b n -> b 1 n')

+3 −3
Original line number Diff line number Diff line
@@ -6,6 +6,8 @@ import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

from audiolm_pytorch.utils import curtail_to_multiple

from einops import rearrange

def exists(val):
@@ -45,9 +47,7 @@ class SoundDataset(Dataset):
            data = data[:self.max_length]

        if exists(self.seq_len_multiple_of):
            mult = self.seq_len_multiple_of
            data_len = len(data)
            data = data[:(data_len // mult * mult)]
            data = curtail_to_multiple(data, self.seq_len_multiple_of)

        return data.float()

+8 −1
Original line number Diff line number Diff line
@@ -9,6 +9,8 @@ import fairseq

from torchaudio.functional import resample

from audiolm_pytorch.utils import curtail_to_multiple

def exists(val):
    """Return True for any value other than None (falsy values such as 0 or '' still count as existing)."""
    if val is None:
        return False
    return True

@@ -17,10 +19,12 @@ class HubertWithKmeans(nn.Module):
        self,
        checkpoint_path,
        kmeans_path,
        target_sample_khz = 50000
        target_sample_khz = 50000,
        seq_len_multiple_of = None
    ):
        super().__init__()
        self.target_sample_khz = target_sample_khz
        self.seq_len_multiple_of = seq_len_multiple_of

        model_path = Path(checkpoint_path)
        kmeans_path = Path(kmeans_path)
@@ -58,6 +62,9 @@ class HubertWithKmeans(nn.Module):
        if exists(input_sample_khz):
            wav_input = resample(wav_input, input_sample_khz, self.target_sample_khz)

        if exists(self.seq_len_multiple_of):
            wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)

        embed = self.model(wav_input, features_only = True)
        embed, packed_shape = pack([embed['x']], '* d')

+3 −0
Original line number Diff line number Diff line
def curtail_to_multiple(t, mult):
    """Trim the last dimension of `t` to the largest length divisible by `mult`.

    Elements are dropped from the end of the last dimension only; all leading
    dimensions are left untouched.

    Args:
        t: an array-like exposing `.shape` and supporting ellipsis slicing
            (`t[..., :k]`).
        mult: the required multiple for the last dimension's length, or None
            to disable curtailing.

    Returns:
        `t` itself when `mult` is None, otherwise `t` sliced along its last
        dimension to `len // mult * mult` elements (possibly zero elements when
        the last dimension is shorter than `mult`).
    """
    # An unset multiple means "no length constraint"; without this guard the
    # floor division below would raise TypeError for callers that pass an
    # optional setting straight through.
    if mult is None:
        return t

    data_len = t.shape[-1]
    return t[..., :(data_len // mult * mult)]
+8 −1
Original line number Diff line number Diff line
@@ -8,6 +8,8 @@ import fairseq

from torchaudio.functional import resample

from audiolm_pytorch.utils import curtail_to_multiple

def exists(val):
    """Tell whether `val` holds a value, i.e. is anything other than None."""
    is_missing = val is None
    return not is_missing

@@ -15,10 +17,12 @@ class FairseqVQWav2Vec(nn.Module):
    def __init__(
        self,
        checkpoint_path,
        target_sample_khz = 24000
        target_sample_khz = 24000,
        seq_len_multiple_of = None
    ):
        super().__init__()
        self.target_sample_khz = target_sample_khz
        self.seq_len_multiple_of = seq_len_multiple_of

        path = Path(checkpoint_path)
        assert path.exists(), f'path {checkpoint_path} does not exist'
@@ -48,6 +52,9 @@ class FairseqVQWav2Vec(nn.Module):
        if exists(input_sample_khz):
            wav_input = resample(wav_input, input_sample_khz, self.target_sample_khz)

        if exists(self.seq_len_multiple_of):
            wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)

        embed = self.model.feature_extractor(wav_input)
        _, codebook_indices = self.model.vector_quantizer.forward_idx(embed)

Loading