Commit 95b8cd23 authored by Phil Wang
Browse files

oops

parent b6e5af78
Loading
Loading
Loading
Loading
+5 −5
Original line number Diff line number Diff line
@@ -304,10 +304,10 @@ class SoundStream(nn.Module):
         feature_loss_weight = 100,
         quantize_dropout = True,
         quantize_dropout_cutoff_index = 0,
-        target_sample_khz = 24000
+        target_sample_hz = 24000
     ):
         super().__init__()
-        self.target_sample_khz = target_sample_khz # for resampling on the fly
+        self.target_sample_hz = target_sample_hz # for resampling on the fly

         self.single_channel = input_channels == 1
         self.strides = strides
@@ -375,10 +375,10 @@ class SoundStream(nn.Module):
         return_discr_loss = False,
         return_discr_losses_separately = False,
         return_recons_only = False,
-        input_sample_khz = None
+        input_sample_hz = None
     ):
-        if exists(input_sample_khz):
-            x = resample(x, input_sample_khz, self.target_sample_khz)
+        if exists(input_sample_hz):
+            x = resample(x, input_sample_hz, self.target_sample_hz)

         x = curtail_to_multiple(x, self.seq_len_multiple_of)

+5 −5
Original line number Diff line number Diff line
@@ -21,7 +21,7 @@ class SoundDataset(Dataset):
         folder,
         exts = ['flac', 'wav'],
         max_length = None,
-        target_sample_khz = None,
+        target_sample_hz = None,
         seq_len_multiple_of = None
     ):
         super().__init__()
@@ -34,7 +34,7 @@ class SoundDataset(Dataset):
         self.files = files
         self.max_length = max_length

-        self.target_sample_khz = target_sample_khz
+        self.target_sample_hz = target_sample_hz
         self.seq_len_multiple_of = seq_len_multiple_of

     def __len__(self):
@@ -42,12 +42,12 @@ class SoundDataset(Dataset):

     def __getitem__(self, idx):
         file = self.files[idx]
-        data, sample_khz = torchaudio.load(file)
+        data, sample_hz = torchaudio.load(file)

         data = rearrange(data, '1 ... -> ...')

-        if exists(self.target_sample_khz):
-            data = torchaudio.functional.resample(data, sample_khz, self.target_sample_khz)
+        if exists(self.target_sample_hz):
+            data = torchaudio.functional.resample(data, sample_hz, self.target_sample_hz)

         if exists(self.max_length):
             data = data[:self.max_length]
+5 −5
Original line number Diff line number Diff line
@@ -19,11 +19,11 @@ class HubertWithKmeans(nn.Module):
         self,
         checkpoint_path,
         kmeans_path,
-        target_sample_khz = 50000,
+        target_sample_hz = 50000,
         seq_len_multiple_of = None
     ):
         super().__init__()
-        self.target_sample_khz = target_sample_khz
+        self.target_sample_hz = target_sample_hz
         self.seq_len_multiple_of = seq_len_multiple_of

         model_path = Path(checkpoint_path)
@@ -55,12 +55,12 @@ class HubertWithKmeans(nn.Module):
         self,
         wav_input,
         flatten = True,
-        input_sample_khz = None
+        input_sample_hz = None
     ):
         device = wav_input.device

-        if exists(input_sample_khz):
-            wav_input = resample(wav_input, input_sample_khz, self.target_sample_khz)
+        if exists(input_sample_hz):
+            wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)

         if exists(self.seq_len_multiple_of):
             wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)
+5 −5
Original line number Diff line number Diff line
@@ -17,11 +17,11 @@ class FairseqVQWav2Vec(nn.Module):
     def __init__(
         self,
         checkpoint_path,
-        target_sample_khz = 24000,
+        target_sample_hz = 24000,
         seq_len_multiple_of = None
     ):
         super().__init__()
-        self.target_sample_khz = target_sample_khz
+        self.target_sample_hz = target_sample_hz
         self.seq_len_multiple_of = seq_len_multiple_of

         path = Path(checkpoint_path)
@@ -47,10 +47,10 @@ class FairseqVQWav2Vec(nn.Module):
         self,
         wav_input,
         flatten = True,
-        input_sample_khz = None
+        input_sample_hz = None
     ):
-        if exists(input_sample_khz):
-            wav_input = resample(wav_input, input_sample_khz, self.target_sample_khz)
+        if exists(input_sample_hz):
+            wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)

         if exists(self.seq_len_multiple_of):
             wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'audiolm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.24',
+  version = '0.0.25',
   license='MIT',
   description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
   author = 'Phil Wang',