Commit 92376359 authored by Phil Wang
Browse files

make n_fft and n_mels completely configurable for soundstream for multi-spectral recon loss

parent 6123c770
Loading
Loading
Loading
Loading
+17 −3
Original line number Diff line number Diff line
import functools
from itertools import cycle
from pathlib import Path

from functools import partial, wraps
from itertools import zip_longest

import torch
from torch import nn, einsum
@@ -28,6 +30,9 @@ def exists(val):
def default(val, d):
    """Return `val` if `exists(val)` holds, otherwise the fallback `d`."""
    if exists(val):
        return val
    return d

def cast_tuple(t, l = 1):
    """Return `t` as-is when it is a tuple, else a tuple of `t` repeated `l` times.

    Only `tuple` instances pass through untouched; any other value
    (including a list) is treated as a single element and repeated.
    An existing tuple is returned regardless of whether its length equals `l`.
    """
    if isinstance(t, tuple):
        return t

    return (t,) * l

# tensor helpers

def l2norm(t, dim = -1):
@@ -450,7 +455,8 @@ class SoundStream(nn.Module):
        enc_cycle_dilations = (1, 3, 9),
        dec_cycle_dilations = (1, 3, 9),
        multi_spectral_window_powers_of_two = tuple(range(6, 12)),
        multi_spectral_n_mels = (8, 16, 32, 64, 64, 64, 64),
        multi_spectral_n_ffts = 512,
        multi_spectral_n_mels = 64,
        recon_loss_weight = 1.,
        multi_spectral_recon_loss_weight = 1.,
        adversarial_loss_weight = 1.,
@@ -547,13 +553,21 @@ class SoundStream(nn.Module):
        self.mel_spec_transforms = nn.ModuleList([])
        self.mel_spec_recon_alphas = []

        for powers, n_mels in zip(multi_spectral_window_powers_of_two, multi_spectral_n_mels):
        num_transforms = len(multi_spectral_window_powers_of_two)
        multi_spectral_n_ffts = cast_tuple(multi_spectral_n_ffts, num_transforms)
        multi_spectral_n_mels = cast_tuple(multi_spectral_n_mels, num_transforms)

        for powers, n_fft, n_mels in zip_longest(multi_spectral_window_powers_of_two, multi_spectral_n_ffts, multi_spectral_n_mels):
            win_length = 2 ** powers
            alpha = (win_length / 2) ** 0.5

            calculated_n_fft = default(max(n_fft, win_length), win_length)  # @AndreyBocharnikov said this is usually win length, but overridable

            # if any audio experts have an opinion about these settings, please submit a PR

            melspec_transform = T.MelSpectrogram(
                sample_rate = target_sample_hz,
                n_fft = win_length,
                n_fft = calculated_n_fft,
                win_length = win_length,
                hop_length = win_length // 4,
                n_mels = n_mels,
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.11.8',
  version = '0.11.9',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',