Commit 0d96c1af authored by Phil Wang's avatar Phil Wang
Browse files

add ability to use lion optimizer

parent aeb4962a
Loading
Loading
Loading
Loading
+11 −4
Original line number Diff line number Diff line
from lion_pytorch import Lion
from torch.optim import AdamW, Adam

def separate_weight_decayable_params(params):
@@ -15,15 +16,15 @@ def get_optimizer(
    eps = 1e-8,
    filter_by_requires_grad = False,
    group_wd_params = True,
    use_lion = False,
    **kwargs
):
    has_wd = wd > 0

    if filter_by_requires_grad:
        params = list(filter(lambda t: t.requires_grad, params))

    if wd == 0:
        return Adam(params, lr = lr, betas = betas, eps = eps)

    if group_wd_params:
    if group_wd_params and has_wd:
        wd_params, no_wd_params = separate_weight_decayable_params(params)

        params = [
@@ -31,4 +32,10 @@ def get_optimizer(
            {'params': no_wd_params, 'weight_decay': 0},
        ]

    if use_lion:
        return Lion(params, lr = lr, betas = betas, weight_decay = wd)

    if not has_wd:
        return Adam(params, lr = lr, betas = betas, eps = eps)

    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
+2 −1
Original line number Diff line number Diff line
@@ -133,6 +133,7 @@ class SoundStreamTrainer(nn.Module):
        apply_grad_penalty_every = 4,
        dl_num_workers = 0,
        accelerate_kwargs: dict = dict(),
        use_lion = False,
        force_clear_prev_results = None  # set to True | False to skip the prompt
    ):
        super().__init__()
@@ -155,7 +156,7 @@ class SoundStreamTrainer(nn.Module):
            one_multiscale_discr_optimizer = get_optimizer(discr.parameters(), lr = lr, wd = wd)
            setattr(self, discr_optimizer_key, one_multiscale_discr_optimizer)

        self.discr_optim = get_optimizer(soundstream.stft_discriminator.parameters(), lr = lr, wd = wd)
        self.discr_optim = get_optimizer(soundstream.stft_discriminator.parameters(), lr = lr, wd = wd, use_lion = use_lion)

        # max grad norm

+2 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.12.1',
  version = '0.12.2',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',
@@ -24,6 +24,7 @@ setup(
    'ema-pytorch',
    'fairseq',
    'joblib',
    'lion-pytorch',
    'local-attention>=1.6.0',
    'Mega-pytorch',
    'scikit-learn',