Loading audiolm_pytorch/optimizer.py +11 −4 Original line number Diff line number Diff line from lion_pytorch import Lion from torch.optim import AdamW, Adam def separate_weight_decayable_params(params): Loading @@ -15,15 +16,15 @@ def get_optimizer( eps = 1e-8, filter_by_requires_grad = False, group_wd_params = True, use_lion = False, **kwargs ): has_wd = wd > 0 if filter_by_requires_grad: params = list(filter(lambda t: t.requires_grad, params)) if wd == 0: return Adam(params, lr = lr, betas = betas, eps = eps) if group_wd_params: if group_wd_params and has_wd: wd_params, no_wd_params = separate_weight_decayable_params(params) params = [ Loading @@ -31,4 +32,10 @@ def get_optimizer( {'params': no_wd_params, 'weight_decay': 0}, ] if use_lion: return Lion(params, lr = lr, betas = betas, weight_decay = wd) if not has_wd: return Adam(params, lr = lr, betas = betas, eps = eps) return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps) audiolm_pytorch/trainer.py +2 −1 Original line number Diff line number Diff line Loading @@ -133,6 +133,7 @@ class SoundStreamTrainer(nn.Module): apply_grad_penalty_every = 4, dl_num_workers = 0, accelerate_kwargs: dict = dict(), use_lion = False, force_clear_prev_results = None # set to True | False to skip the prompt ): super().__init__() Loading @@ -155,7 +156,7 @@ class SoundStreamTrainer(nn.Module): one_multiscale_discr_optimizer = get_optimizer(discr.parameters(), lr = lr, wd = wd) setattr(self, discr_optimizer_key, one_multiscale_discr_optimizer) self.discr_optim = get_optimizer(soundstream.stft_discriminator.parameters(), lr = lr, wd = wd) self.discr_optim = get_optimizer(soundstream.stft_discriminator.parameters(), lr = lr, wd = wd, use_lion = use_lion) # max grad norm Loading setup.py +2 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.12.1', version = '0.12.2', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading @@ -24,6 +24,7 @@ setup( 'ema-pytorch', 'fairseq', 'joblib', 'lion-pytorch', 'local-attention>=1.6.0', 'Mega-pytorch', 'scikit-learn', Loading Loading
audiolm_pytorch/optimizer.py +11 −4 Original line number Diff line number Diff line from lion_pytorch import Lion from torch.optim import AdamW, Adam def separate_weight_decayable_params(params): Loading @@ -15,15 +16,15 @@ def get_optimizer( eps = 1e-8, filter_by_requires_grad = False, group_wd_params = True, use_lion = False, **kwargs ): has_wd = wd > 0 if filter_by_requires_grad: params = list(filter(lambda t: t.requires_grad, params)) if wd == 0: return Adam(params, lr = lr, betas = betas, eps = eps) if group_wd_params: if group_wd_params and has_wd: wd_params, no_wd_params = separate_weight_decayable_params(params) params = [ Loading @@ -31,4 +32,10 @@ def get_optimizer( {'params': no_wd_params, 'weight_decay': 0}, ] if use_lion: return Lion(params, lr = lr, betas = betas, weight_decay = wd) if not has_wd: return Adam(params, lr = lr, betas = betas, eps = eps) return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
audiolm_pytorch/trainer.py +2 −1 Original line number Diff line number Diff line Loading @@ -133,6 +133,7 @@ class SoundStreamTrainer(nn.Module): apply_grad_penalty_every = 4, dl_num_workers = 0, accelerate_kwargs: dict = dict(), use_lion = False, force_clear_prev_results = None # set to True | False to skip the prompt ): super().__init__() Loading @@ -155,7 +156,7 @@ class SoundStreamTrainer(nn.Module): one_multiscale_discr_optimizer = get_optimizer(discr.parameters(), lr = lr, wd = wd) setattr(self, discr_optimizer_key, one_multiscale_discr_optimizer) self.discr_optim = get_optimizer(soundstream.stft_discriminator.parameters(), lr = lr, wd = wd) self.discr_optim = get_optimizer(soundstream.stft_discriminator.parameters(), lr = lr, wd = wd, use_lion = use_lion) # max grad norm Loading
setup.py +2 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.12.1', version = '0.12.2', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading @@ -24,6 +24,7 @@ setup( 'ema-pytorch', 'fairseq', 'joblib', 'lion-pytorch', 'local-attention>=1.6.0', 'Mega-pytorch', 'scikit-learn', Loading