Loading audiolm_pytorch/audiolm_pytorch.py +3 −0 Original line number Diff line number Diff line Loading @@ -17,6 +17,7 @@ from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec from audiolm_pytorch.hubert_kmeans import HubertWithKmeans from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME from audiolm_pytorch.utils import curtail_to_multiple from torchaudio.functional import resample Loading Loading @@ -374,6 +375,8 @@ class SoundStream(nn.Module): if exists(input_sample_khz): x = resample(x, input_sample_khz, self.target_sample_khz) x = curtail_to_multiple(x, self.seq_len_multiple_of) if x.ndim == 2: x = rearrange(x, 'b n -> b 1 n') Loading audiolm_pytorch/data.py +3 −3 Original line number Diff line number Diff line Loading @@ -6,6 +6,8 @@ import torch from torch.nn.utils.rnn import pad_sequence from torch.utils.data import Dataset, DataLoader from audiolm_pytorch.utils import curtail_to_multiple from einops import rearrange def exists(val): Loading Loading @@ -45,9 +47,7 @@ class SoundDataset(Dataset): data = data[:self.max_length] if exists(self.seq_len_multiple_of): mult = self.seq_len_multiple_of data_len = len(data) data = data[:(data_len // mult * mult)] data = curtail_to_multiple(data, self.seq_len_multiple_of) return data.float() Loading audiolm_pytorch/hubert_kmeans.py +8 −1 Original line number Diff line number Diff line Loading @@ -9,6 +9,8 @@ import fairseq from torchaudio.functional import resample from audiolm_pytorch.utils import curtail_to_multiple def exists(val): return val is not None Loading @@ -17,10 +19,12 @@ class HubertWithKmeans(nn.Module): self, checkpoint_path, kmeans_path, target_sample_khz = 50000 target_sample_khz = 50000, seq_len_multiple_of = None ): super().__init__() self.target_sample_khz = target_sample_khz self.seq_len_multiple_of = seq_len_multiple_of model_path = Path(checkpoint_path) kmeans_path = Path(kmeans_path) Loading Loading @@ -58,6 +62,9 @@ class HubertWithKmeans(nn.Module): if exists(input_sample_khz): wav_input = resample(wav_input, input_sample_khz, self.target_sample_khz) if exists(self.seq_len_multiple_of): wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of) embed = self.model(wav_input, features_only = True) embed, packed_shape = pack([embed['x']], '* d') Loading audiolm_pytorch/utils.py 0 → 100644 +3 −0 Original line number Diff line number Diff line def curtail_to_multiple(t, mult): data_len = t.shape[-1] return t[..., :(data_len // mult * mult)] audiolm_pytorch/vq_wav2vec.py +8 −1 Original line number Diff line number Diff line Loading @@ -8,6 +8,8 @@ import fairseq from torchaudio.functional import resample from audiolm_pytorch.utils import curtail_to_multiple def exists(val): return val is not None Loading @@ -15,10 +17,12 @@ class FairseqVQWav2Vec(nn.Module): def __init__( self, checkpoint_path, target_sample_khz = 24000 target_sample_khz = 24000, seq_len_multiple_of = None ): super().__init__() self.target_sample_khz = target_sample_khz self.seq_len_multiple_of = seq_len_multiple_of path = Path(checkpoint_path) assert path.exists(), f'path {checkpoint_path} does not exist' Loading Loading @@ -48,6 +52,9 @@ class FairseqVQWav2Vec(nn.Module): if exists(input_sample_khz): wav_input = resample(wav_input, input_sample_khz, self.target_sample_khz) if exists(self.seq_len_multiple_of): wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of) embed = self.model.feature_extractor(wav_input) _, codebook_indices = self.model.vector_quantizer.forward_idx(embed) Loading Loading
audiolm_pytorch/audiolm_pytorch.py +3 −0 Original line number Diff line number Diff line Loading @@ -17,6 +17,7 @@ from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec from audiolm_pytorch.hubert_kmeans import HubertWithKmeans from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME from audiolm_pytorch.utils import curtail_to_multiple from torchaudio.functional import resample Loading Loading @@ -374,6 +375,8 @@ class SoundStream(nn.Module): if exists(input_sample_khz): x = resample(x, input_sample_khz, self.target_sample_khz) x = curtail_to_multiple(x, self.seq_len_multiple_of) if x.ndim == 2: x = rearrange(x, 'b n -> b 1 n') Loading
audiolm_pytorch/data.py +3 −3 Original line number Diff line number Diff line Loading @@ -6,6 +6,8 @@ import torch from torch.nn.utils.rnn import pad_sequence from torch.utils.data import Dataset, DataLoader from audiolm_pytorch.utils import curtail_to_multiple from einops import rearrange def exists(val): Loading Loading @@ -45,9 +47,7 @@ class SoundDataset(Dataset): data = data[:self.max_length] if exists(self.seq_len_multiple_of): mult = self.seq_len_multiple_of data_len = len(data) data = data[:(data_len // mult * mult)] data = curtail_to_multiple(data, self.seq_len_multiple_of) return data.float() Loading
audiolm_pytorch/hubert_kmeans.py +8 −1 Original line number Diff line number Diff line Loading @@ -9,6 +9,8 @@ import fairseq from torchaudio.functional import resample from audiolm_pytorch.utils import curtail_to_multiple def exists(val): return val is not None Loading @@ -17,10 +19,12 @@ class HubertWithKmeans(nn.Module): self, checkpoint_path, kmeans_path, target_sample_khz = 50000 target_sample_khz = 50000, seq_len_multiple_of = None ): super().__init__() self.target_sample_khz = target_sample_khz self.seq_len_multiple_of = seq_len_multiple_of model_path = Path(checkpoint_path) kmeans_path = Path(kmeans_path) Loading Loading @@ -58,6 +62,9 @@ class HubertWithKmeans(nn.Module): if exists(input_sample_khz): wav_input = resample(wav_input, input_sample_khz, self.target_sample_khz) if exists(self.seq_len_multiple_of): wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of) embed = self.model(wav_input, features_only = True) embed, packed_shape = pack([embed['x']], '* d') Loading
audiolm_pytorch/utils.py 0 → 100644 +3 −0 Original line number Diff line number Diff line def curtail_to_multiple(t, mult): data_len = t.shape[-1] return t[..., :(data_len // mult * mult)]
audiolm_pytorch/vq_wav2vec.py +8 −1 Original line number Diff line number Diff line Loading @@ -8,6 +8,8 @@ import fairseq from torchaudio.functional import resample from audiolm_pytorch.utils import curtail_to_multiple def exists(val): return val is not None Loading @@ -15,10 +17,12 @@ class FairseqVQWav2Vec(nn.Module): def __init__( self, checkpoint_path, target_sample_khz = 24000 target_sample_khz = 24000, seq_len_multiple_of = None ): super().__init__() self.target_sample_khz = target_sample_khz self.seq_len_multiple_of = seq_len_multiple_of path = Path(checkpoint_path) assert path.exists(), f'path {checkpoint_path} does not exist' Loading Loading @@ -48,6 +52,9 @@ class FairseqVQWav2Vec(nn.Module): if exists(input_sample_khz): wav_input = resample(wav_input, input_sample_khz, self.target_sample_khz) if exists(self.seq_len_multiple_of): wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of) embed = self.model.feature_extractor(wav_input) _, codebook_indices = self.model.vector_quantizer.forward_idx(embed) Loading