Loading audiolm_pytorch/__init__.py +1 −1 Original line number Diff line number Diff line Loading @@ -7,4 +7,4 @@ from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransf from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec from audiolm_pytorch.hubert_kmeans import HubertWithKmeans from audiolm_pytorch.trainer import SoundStreamTrainer from audiolm_pytorch.trainer import SoundStreamTrainer, SemanticTransformerTrainer audiolm_pytorch/audiolm_pytorch.py +8 −0 Original line number Diff line number Diff line Loading @@ -348,6 +348,14 @@ class SemanticTransformer(nn.Module): self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.to_logits = nn.Linear(dim, num_semantic_tokens + 1) def non_wav2vec_parameters(self): return ( set([*self.semantic_embedding.parameters()]) | set([self.start_token]) | set([*self.transformer.parameters()]) | set([*self.to_logits.parameters()]) ) @property def device(self): return next(self.parameters()).device Loading audiolm_pytorch/trainer.py +197 −0 Original line number Diff line number Diff line Loading @@ -5,6 +5,9 @@ from pathlib import Path from shutil import rmtree from PIL import Image from typing import Union, List, Optional from typeguard import typechecked import torch import torchaudio from torch import nn Loading @@ -21,6 +24,17 @@ from audiolm_pytorch.optimizer import get_optimizer from ema_pytorch import EMA from audiolm_pytorch.soundstream import SoundStream from audiolm_pytorch.audiolm_pytorch import ( SemanticTransformer, CoarseTransformer, CoarseTransformerWrapper, FineTransformer, FineTransformerWrapper, FairseqVQWav2Vec, HubertWithKmeans ) from audiolm_pytorch.data import SoundDataset, get_dataloader from accelerate import Accelerator Loading Loading @@ -320,3 +334,186 @@ class SoundStreamTrainer(nn.Module): log_fn(logs) self.print('training complete') # semantic transformer trainer @typechecked class SemanticTransformerTrainer(nn.Module): def __init__( self, transformer: SemanticTransformer, *, num_train_steps, batch_size, data_max_length = None, folder, lr = 3e-4, grad_accum_every = 1, wd = 0., max_grad_norm = 0.5, valid_frac = 0.05, random_split_seed = 42, save_results_every = 100, save_model_every = 1000, results_folder = './results', accelerate_kwargs: dict = dict() ): super().__init__() self.accelerator = Accelerator(**accelerate_kwargs) self.transformer = transformer self.register_buffer('steps', torch.Tensor([0])) self.num_train_steps = num_train_steps self.batch_size = batch_size self.grad_accum_every = grad_accum_every # optimizers self.optim = get_optimizer(transformer.non_wav2vec_parameters(), lr = lr, wd = wd) # max grad norm self.max_grad_norm = max_grad_norm # create dataset self.ds = SoundDataset( folder, max_length = data_max_length, target_sample_hz = transformer.wav2vec.target_sample_hz, seq_len_multiple_of = transformer.wav2vec.seq_len_multiple_of ) # split for validation if valid_frac > 0: train_size = int((1 - valid_frac) * len(self.ds)) valid_size = len(self.ds) - train_size self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed)) self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples') else: self.valid_ds = self.ds self.print(f'training with shared training and valid dataset of {len(self.ds)} samples') # dataloader self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True) self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True) # prepare with accelerator ( self.transformer, self.optim, self.dl, self.valid_dl ) = self.accelerator.prepare( self.transformer, self.optim, self.dl, self.valid_dl ) # dataloader iterators self.dl_iter = cycle(self.dl) self.valid_dl_iter = cycle(self.valid_dl) self.save_model_every = save_model_every self.save_results_every = save_results_every self.results_folder = Path(results_folder) if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'): rmtree(str(self.results_folder)) self.results_folder.mkdir(parents = True, exist_ok = True) def print(self, msg): self.accelerator.print(msg) @property def device(self): return self.accelerator.device @property def is_distributed(self): return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) @property def is_main(self): return self.accelerator.is_main_process @property def is_local_main(self): return self.accelerator.is_local_main_process def train_step(self): device = self.device steps = int(self.steps.item()) self.transformer.train() # logs logs = {} # update vae (generator) for _ in range(self.grad_accum_every): wave = next(self.dl_iter).to(device) loss = self.transformer(raw_wave = wave, return_loss = True) self.accelerator.backward(loss / self.grad_accum_every) accum_log(logs, {'loss': loss.item() / self.grad_accum_every}) if exists(self.max_grad_norm): self.accelerator.clip_grad_norm_(self.transformer.non_wav2vec_parameters(), self.max_grad_norm) self.optim.step() self.optim.zero_grad() # log self.print(f"{steps}: loss: {logs['loss']}") # sample results every so often if self.is_main and not (steps % self.save_results_every): model = self.transformer filename = str(steps) model.eval() wave = next(self.valid_dl_iter).to(device) with torch.no_grad(): valid_loss = model(raw_wave = wave, return_loss = True) self.print(f'{steps}: valid loss {valid_loss}') # save model every so often if self.is_main and not (steps % self.save_model_every): state_dict = self.transformer.state_dict() model_path = str(self.results_folder / f'semantic.transformer.{steps}.pt') torch.save(state_dict, model_path) self.print(f'{steps}: saving model to {str(self.results_folder)}') self.steps += 1 return logs def train(self, log_fn = noop): while self.steps < self.num_train_steps: logs = self.train_step() log_fn(logs) self.print('training complete') setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.50', version = '0.0.51', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/__init__.py +1 −1 Original line number Diff line number Diff line Loading @@ -7,4 +7,4 @@ from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransf from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec from audiolm_pytorch.hubert_kmeans import HubertWithKmeans from audiolm_pytorch.trainer import SoundStreamTrainer from audiolm_pytorch.trainer import SoundStreamTrainer, SemanticTransformerTrainer
audiolm_pytorch/audiolm_pytorch.py +8 −0 Original line number Diff line number Diff line Loading @@ -348,6 +348,14 @@ class SemanticTransformer(nn.Module): self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.to_logits = nn.Linear(dim, num_semantic_tokens + 1) def non_wav2vec_parameters(self): return ( set([*self.semantic_embedding.parameters()]) | set([self.start_token]) | set([*self.transformer.parameters()]) | set([*self.to_logits.parameters()]) ) @property def device(self): return next(self.parameters()).device Loading
audiolm_pytorch/trainer.py +197 −0 Original line number Diff line number Diff line Loading @@ -5,6 +5,9 @@ from pathlib import Path from shutil import rmtree from PIL import Image from typing import Union, List, Optional from typeguard import typechecked import torch import torchaudio from torch import nn Loading @@ -21,6 +24,17 @@ from audiolm_pytorch.optimizer import get_optimizer from ema_pytorch import EMA from audiolm_pytorch.soundstream import SoundStream from audiolm_pytorch.audiolm_pytorch import ( SemanticTransformer, CoarseTransformer, CoarseTransformerWrapper, FineTransformer, FineTransformerWrapper, FairseqVQWav2Vec, HubertWithKmeans ) from audiolm_pytorch.data import SoundDataset, get_dataloader from accelerate import Accelerator Loading Loading @@ -320,3 +334,186 @@ class SoundStreamTrainer(nn.Module): log_fn(logs) self.print('training complete') # semantic transformer trainer @typechecked class SemanticTransformerTrainer(nn.Module): def __init__( self, transformer: SemanticTransformer, *, num_train_steps, batch_size, data_max_length = None, folder, lr = 3e-4, grad_accum_every = 1, wd = 0., max_grad_norm = 0.5, valid_frac = 0.05, random_split_seed = 42, save_results_every = 100, save_model_every = 1000, results_folder = './results', accelerate_kwargs: dict = dict() ): super().__init__() self.accelerator = Accelerator(**accelerate_kwargs) self.transformer = transformer self.register_buffer('steps', torch.Tensor([0])) self.num_train_steps = num_train_steps self.batch_size = batch_size self.grad_accum_every = grad_accum_every # optimizers self.optim = get_optimizer(transformer.non_wav2vec_parameters(), lr = lr, wd = wd) # max grad norm self.max_grad_norm = max_grad_norm # create dataset self.ds = SoundDataset( folder, max_length = data_max_length, target_sample_hz = transformer.wav2vec.target_sample_hz, seq_len_multiple_of = transformer.wav2vec.seq_len_multiple_of ) # split for validation if valid_frac > 0: train_size = int((1 - valid_frac) * len(self.ds)) valid_size = len(self.ds) - train_size self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed)) self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples') else: self.valid_ds = self.ds self.print(f'training with shared training and valid dataset of {len(self.ds)} samples') # dataloader self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True) self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True) # prepare with accelerator ( self.transformer, self.optim, self.dl, self.valid_dl ) = self.accelerator.prepare( self.transformer, self.optim, self.dl, self.valid_dl ) # dataloader iterators self.dl_iter = cycle(self.dl) self.valid_dl_iter = cycle(self.valid_dl) self.save_model_every = save_model_every self.save_results_every = save_results_every self.results_folder = Path(results_folder) if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'): rmtree(str(self.results_folder)) self.results_folder.mkdir(parents = True, exist_ok = True) def print(self, msg): self.accelerator.print(msg) @property def device(self): return self.accelerator.device @property def is_distributed(self): return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) @property def is_main(self): return self.accelerator.is_main_process @property def is_local_main(self): return self.accelerator.is_local_main_process def train_step(self): device = self.device steps = int(self.steps.item()) self.transformer.train() # logs logs = {} # update vae (generator) for _ in range(self.grad_accum_every): wave = next(self.dl_iter).to(device) loss = self.transformer(raw_wave = wave, return_loss = True) self.accelerator.backward(loss / self.grad_accum_every) accum_log(logs, {'loss': loss.item() / self.grad_accum_every}) if exists(self.max_grad_norm): self.accelerator.clip_grad_norm_(self.transformer.non_wav2vec_parameters(), self.max_grad_norm) self.optim.step() self.optim.zero_grad() # log self.print(f"{steps}: loss: {logs['loss']}") # sample results every so often if self.is_main and not (steps % self.save_results_every): model = self.transformer filename = str(steps) model.eval() wave = next(self.valid_dl_iter).to(device) with torch.no_grad(): valid_loss = model(raw_wave = wave, return_loss = True) self.print(f'{steps}: valid loss {valid_loss}') # save model every so often if self.is_main and not (steps % self.save_model_every): state_dict = self.transformer.state_dict() model_path = str(self.results_folder / f'semantic.transformer.{steps}.pt') torch.save(state_dict, model_path) self.print(f'{steps}: saving model to {str(self.results_folder)}') self.steps += 1 return logs def train(self, log_fn = noop): while self.steps < self.num_train_steps: logs = self.train_step() log_fn(logs) self.print('training complete')
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.50', version = '0.0.51', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading