Commit 846740c3 authored by Phil Wang
Browse files

forget about weight decay, and make sure to clip wavs by data_max_length

parent 096a21b8
Loading
Loading
Loading
Loading
+22 −6
Original line number Diff line number Diff line
from functools import wraps

import torch
import torch.nn.functional as F
from torch import nn, einsum
@@ -30,6 +32,21 @@ def round_down_nearest_multiple(n, divisor):
def Sequential(*modules):
    """Build an ``nn.Sequential`` from ``modules``, dropping every entry for which ``exists`` is falsy."""
    kept = [module for module in modules if exists(module)]
    return nn.Sequential(*kept)

# decorators

def once(fn):
    """Wrap ``fn`` so that only its first invocation is actually executed.

    The first call forwards all positional and keyword arguments to ``fn``
    and returns its result. Every later call does nothing and returns ``None``.

    Args:
        fn: the callable to guard.

    Returns:
        A wrapper around ``fn`` that carries its metadata (``functools.wraps``).
    """
    called = False

    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return
        # mark as consumed before calling, so a raising fn is still not retried
        called = True
        return fn(*args, **kwargs)

    return inner

# `print` guarded by `once`: only the first call prints, later calls do nothing and return None
print_once = once(print)

# tensor functions

def log(t, eps = 1e-20):
@@ -240,8 +257,7 @@ class AudioSpectrogramTransformer(nn.Module):
        spec_aug_stretch_factor = 0.8,
        spec_aug_freq_mask = 80,
        spec_aug_time_mask = 80,
        dual_patchnorm = True,
        patch_dropout_prob = 0.5
        patch_dropout_prob = 0.25
    ):
        super().__init__()
        self.dim = dim
@@ -251,9 +267,9 @@ class AudioSpectrogramTransformer(nn.Module):

        self.to_patch_tokens = Sequential(
            Rearrange('b (h p1) (w p2) -> b h w (p1 p2)', p1 = self.patch_size[0], p2 = self.patch_size[1]),
            nn.LayerNorm(patch_input_dim) if dual_patchnorm else None,
            nn.LayerNorm(patch_input_dim),
            nn.Linear(patch_input_dim, dim),
            nn.LayerNorm(dim) if dual_patchnorm else None
            nn.LayerNorm(dim)
        )

        self.spec = Spectrogram(
@@ -302,7 +318,7 @@ class AudioSpectrogramTransformer(nn.Module):
        rounded_height, rounded_width = map(lambda args: round_down_nearest_multiple(*args), ((height, patch_height), (width, patch_width)))

        if (height, width) != (rounded_height, rounded_width): # just keep printing to be annoying until it is fixed
            print(f'spectrogram yielded shape of {(height, width)}, but had to be cropped to {(rounded_height, rounded_width)} to be patchified for transformer')
            print_once(f'spectrogram yielded shape of {(height, width)}, but had to be cropped to {(rounded_height, rounded_width)} to be patchified for transformer')

        x = x[..., :rounded_height, :rounded_width]

+12 −30
Original line number Diff line number Diff line
@@ -14,7 +14,7 @@ from beartype.typing import Union, List, Optional, Tuple, Callable

import torch
from torch import nn
from torch.optim import AdamW, Adam
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence

@@ -101,32 +101,6 @@ def separate_weight_decayable_params(params):
        param_list.append(param)
    return wd_params, no_wd_params

def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.99),
    eps = 1e-8,
    filter_by_requires_grad = False,
    group_wd_params = True,
    **kwargs
):
    """Construct an Adam or AdamW optimizer for ``params``.

    Args:
        params: iterable of parameters to optimize.
        lr: learning rate.
        wd: weight decay; when it equals 0 a plain ``Adam`` is returned.
        betas: Adam beta coefficients.
        eps: Adam epsilon.
        filter_by_requires_grad: keep only parameters whose ``requires_grad`` is truthy.
        group_wd_params: when weight decay is active, split the parameters with
            ``separate_weight_decayable_params`` and give the second group a weight decay of 0.
        **kwargs: accepted but not used.

    Returns:
        ``Adam`` when ``wd == 0``, otherwise ``AdamW``.
    """
    if filter_by_requires_grad:
        params = [param for param in params if param.requires_grad]

    # hyperparameters common to both optimizer classes
    shared = dict(lr = lr, betas = betas, eps = eps)

    if wd == 0:
        return Adam(params, **shared)

    if not group_wd_params:
        return AdamW(params, weight_decay = wd, **shared)

    decayed, undecayed = separate_weight_decayable_params(params)

    param_groups = [
        {'params': decayed},
        {'params': undecayed, 'weight_decay': 0},
    ]

    return AdamW(param_groups, weight_decay = wd, **shared)

# dataloader functions

def collate_one_or_multiple_tensors(fn):
@@ -180,7 +154,7 @@ class MuLaNTrainer(nn.Module):
        folder = None,
        lr = 3e-4,
        grad_accum_every = 1,
        wd = 0.,
        betas = (0.9, 0.99),
        max_grad_norm = 0.5,
        valid_frac = 0.05,
        random_split_seed = 42,
@@ -201,12 +175,14 @@ class MuLaNTrainer(nn.Module):

        # optimizers

        self.optim = get_optimizer(mulan.parameters(), lr = lr, wd = wd)
        self.optim = Adam(mulan.parameters(), lr = lr, betas = betas)

        # max grad norm

        self.max_grad_norm = max_grad_norm

        self.data_max_length = data_max_length

        # create dataset

        self.ds = dataset
@@ -311,7 +287,12 @@ class MuLaNTrainer(nn.Module):
            self.ds_fields = determine_types(data, DATASET_FIELD_TYPE_CONFIG)
            assert not has_duplicates(self.ds_fields), 'dataset fields must not have duplicate field names'

        return dict(zip(self.ds_fields, data))
        data_kwargs =  dict(zip(self.ds_fields, data))

        wavs = data_kwargs['wavs']
        data_kwargs.update(wavs = wavs[..., :self.data_max_length])

        return data_kwargs

    def train_step(self):
        device = self.device
@@ -328,6 +309,7 @@ class MuLaNTrainer(nn.Module):

        for _ in range(self.grad_accum_every):
            data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter))

            loss = self.mulan(**data_kwargs)

            self.accelerator.backward(loss / self.grad_accum_every)
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'musiclm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.12',
  version = '0.0.14',
  license='MIT',
  description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
  author = 'Phil Wang',