Loading audiolm_pytorch/__init__.py +1 −1 Original line number Diff line number Diff line Loading @@ -7,4 +7,4 @@ from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransf from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec from audiolm_pytorch.hubert_kmeans import HubertWithKmeans from audiolm_pytorch.trainer import SoundStreamTrainer, SemanticTransformerTrainer from audiolm_pytorch.trainer import SoundStreamTrainer, SemanticTransformerTrainer, FineTransformerTrainer audiolm_pytorch/trainer.py +190 −0 Original line number Diff line number Diff line Loading @@ -517,3 +517,193 @@ class SemanticTransformerTrainer(nn.Module): log_fn(logs) self.print('training complete') # semantic transformer trainer @typechecked class FineTransformerTrainer(nn.Module): def __init__( self, transformer: FineTransformer, soundstream: SoundStream, *, num_train_steps, batch_size, data_max_length = None, folder, lr = 3e-4, grad_accum_every = 1, wd = 0., max_grad_norm = 0.5, valid_frac = 0.05, random_split_seed = 42, save_results_every = 100, save_model_every = 1000, results_folder = './results', accelerate_kwargs: dict = dict() ): super().__init__() self.accelerator = Accelerator(**accelerate_kwargs) self.transformer = transformer self.soundstream = soundstream self.train_wrapper = FineTransformerWrapper( soundstream = soundstream, transformer = transformer ) self.register_buffer('steps', torch.Tensor([0])) self.num_train_steps = num_train_steps self.batch_size = batch_size self.grad_accum_every = grad_accum_every # optimizers self.optim = get_optimizer(transformer.parameters(), lr = lr, wd = wd) # max grad norm self.max_grad_norm = max_grad_norm # create dataset self.ds = SoundDataset( folder, max_length = data_max_length, target_sample_hz = soundstream.target_sample_hz, seq_len_multiple_of = soundstream.seq_len_multiple_of ) # split for validation if valid_frac > 0: train_size = int((1 - valid_frac) * len(self.ds)) valid_size = len(self.ds) - train_size 
self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed)) self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples') else: self.valid_ds = self.ds self.print(f'training with shared training and valid dataset of {len(self.ds)} samples') # dataloader self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True) self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True) # prepare with accelerator ( self.transformer, self.optim, self.dl, self.valid_dl ) = self.accelerator.prepare( self.transformer, self.optim, self.dl, self.valid_dl ) # dataloader iterators self.dl_iter = cycle(self.dl) self.valid_dl_iter = cycle(self.valid_dl) self.save_model_every = save_model_every self.save_results_every = save_results_every self.results_folder = Path(results_folder) if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'): rmtree(str(self.results_folder)) self.results_folder.mkdir(parents = True, exist_ok = True) self.train_wrapper.to(self.device) def print(self, msg): self.accelerator.print(msg) @property def device(self): return self.accelerator.device @property def is_distributed(self): return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) @property def is_main(self): return self.accelerator.is_main_process @property def is_local_main(self): return self.accelerator.is_local_main_process def train_step(self): device = self.device steps = int(self.steps.item()) self.transformer.train() # logs logs = {} # update vae (generator) for _ in range(self.grad_accum_every): wave = next(self.dl_iter).to(device) loss = self.train_wrapper(raw_wave = wave, return_loss = True) self.accelerator.backward(loss / self.grad_accum_every) accum_log(logs, {'loss': loss.item() / 
self.grad_accum_every}) if exists(self.max_grad_norm): self.accelerator.clip_grad_norm_(self.transformer.parameters(), self.max_grad_norm) self.optim.step() self.optim.zero_grad() # log self.print(f"{steps}: loss: {logs['loss']}") # sample results every so often if self.is_main and not (steps % self.save_results_every): filename = str(steps) self.train_wrapper.eval() wave = next(self.valid_dl_iter).to(device) with torch.no_grad(): valid_loss = self.train_wrapper(raw_wave = wave, return_loss = True) self.print(f'{steps}: valid loss {valid_loss}') # save model every so often if self.is_main and not (steps % self.save_model_every): state_dict = self.transformer.state_dict() model_path = str(self.results_folder / f'fine.transformer.{steps}.pt') torch.save(state_dict, model_path) self.print(f'{steps}: saving model to {str(self.results_folder)}') self.steps += 1 return logs def train(self, log_fn = noop): while self.steps < self.num_train_steps: logs = self.train_step() log_fn(logs) self.print('training complete') setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.51', version = '0.0.52', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/__init__.py +1 −1 Original line number Diff line number Diff line Loading @@ -7,4 +7,4 @@ from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransf from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec from audiolm_pytorch.hubert_kmeans import HubertWithKmeans from audiolm_pytorch.trainer import SoundStreamTrainer, SemanticTransformerTrainer from audiolm_pytorch.trainer import SoundStreamTrainer, SemanticTransformerTrainer, FineTransformerTrainer
# audiolm_pytorch/trainer.py +190 −0 · diff hunk @@ -517,3 +517,193 @@
# hunk context (end of SemanticTransformerTrainer.train): log_fn(logs); self.print('training complete')

# fine transformer trainer

@typechecked
class FineTransformerTrainer(nn.Module):
    """Training loop for a `FineTransformer`, driven by HuggingFace Accelerate.

    Raw waveforms are read from `folder` through `SoundDataset` and fed to a
    `FineTransformerWrapper` built from `soundstream` and `transformer`; the
    wrapper is called with `return_loss = True` and the returned loss is
    back-propagated. Only `transformer.parameters()` are handed to the
    optimizer.

    Args:
        transformer: the fine transformer to optimize.
        soundstream: supplies `target_sample_hz` and `seq_len_multiple_of`
            for the dataset and is passed to the training wrapper.
        num_train_steps: `train()` loops until the step counter reaches this.
        batch_size: batch size for both the training and validation loaders.
        data_max_length: forwarded to `SoundDataset` as `max_length`.
        folder: dataset location, forwarded to `SoundDataset`.
        lr: learning rate passed to `get_optimizer`.
        grad_accum_every: number of micro-batches whose gradients are
            accumulated before each optimizer step.
        wd: weight decay passed to `get_optimizer`.
        max_grad_norm: gradient clipping threshold; clipping is skipped when
            `exists(max_grad_norm)` is false.
        valid_frac: fraction of the dataset held out for validation. When it
            is not greater than 0 the training set doubles as validation set.
        random_split_seed: seed of the generator used for the split.
        save_results_every: a validation loss is computed and printed every
            this many steps (main process only).
        save_model_every: the transformer `state_dict` is written to
            `results_folder` every this many steps (main process only).
        results_folder: directory for checkpoints; created if missing.
        accelerate_kwargs: keyword arguments for `Accelerator(...)`. Only
            unpacked, never mutated, so the shared default is harmless.

    Raises:
        ValueError: if the training or validation split ends up empty.
    """

    def __init__(
        self,
        transformer: FineTransformer,
        soundstream: SoundStream,
        *,
        num_train_steps,
        batch_size,
        data_max_length = None,
        folder,
        lr = 3e-4,
        grad_accum_every = 1,
        wd = 0.,
        max_grad_norm = 0.5,
        valid_frac = 0.05,
        random_split_seed = 42,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        accelerate_kwargs: dict = dict()
    ):
        super().__init__()
        self.accelerator = Accelerator(**accelerate_kwargs)

        self.transformer = transformer
        self.soundstream = soundstream

        # NOTE(review): the wrapper captures `transformer` *before*
        # `accelerator.prepare` below, so forward passes go through this
        # original module object rather than whatever `prepare` returns -
        # confirm this is intended for multi-process runs.
        self.train_wrapper = FineTransformerWrapper(
            soundstream = soundstream,
            transformer = transformer
        )

        # step counter kept as a buffer so it travels with the module state
        self.register_buffer('steps', torch.Tensor([0]))

        self.num_train_steps = num_train_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        # optimizers

        self.optim = get_optimizer(transformer.parameters(), lr = lr, wd = wd)

        # max grad norm

        self.max_grad_norm = max_grad_norm

        # create dataset

        self.ds = SoundDataset(
            folder,
            max_length = data_max_length,
            target_sample_hz = soundstream.target_sample_hz,
            seq_len_multiple_of = soundstream.seq_len_multiple_of
        )

        # split for validation

        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        # an empty split can never yield a batch, so fail here with a clear
        # message instead of failing (or stalling) on the first `next(...)`

        if len(self.ds) == 0 or len(self.valid_ds) == 0:
            raise ValueError(
                f'not enough samples to train: {len(self.ds)} training and '
                f'{len(self.valid_ds)} validation samples (valid_frac = {valid_frac})'
            )

        # dataloader

        self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True)

        self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True)

        # prepare with accelerator

        (
            self.transformer,
            self.optim,
            self.dl,
            self.valid_dl
        ) = self.accelerator.prepare(
            self.transformer,
            self.optim,
            self.dl,
            self.valid_dl
        )

        # dataloader iterators

        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        self.results_folder = Path(results_folder)

        # only prompt when the folder already holds something; `any` stops at
        # the first entry instead of listing the whole tree

        has_previous_results = any(self.results_folder.glob('**/*'))

        if has_previous_results and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)

        self.train_wrapper.to(self.device)

    def print(self, msg):
        """Print `msg` through the accelerator's `print`."""
        self.accelerator.print(msg)

    @property
    def device(self):
        """Device chosen by the accelerator."""
        return self.accelerator.device

    @property
    def is_distributed(self):
        """False only for a single process with `DistributedType.NO`."""
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    @property
    def is_main(self):
        """Whether this is the accelerator's main process."""
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        """Whether this is the accelerator's local main process."""
        return self.accelerator.is_local_main_process

    def train_step(self):
        """Run one optimizer step and return the logs collected for it.

        Accumulates gradients over `grad_accum_every` batches, optionally
        clips them, steps the optimizer, then (main process only) computes a
        validation loss every `save_results_every` steps and saves the
        transformer `state_dict` every `save_model_every` steps.

        Returns:
            dict with a 'loss' entry built up through `accum_log`.
        """
        device = self.device

        steps = int(self.steps.item())

        # the validation branch below switches the wrapper to eval mode, so
        # training mode is re-enabled at the start of every step
        self.transformer.train()

        # logs

        logs = {}

        # accumulate gradients of the fine transformer over the micro-batches

        for _ in range(self.grad_accum_every):
            wave = next(self.dl_iter).to(device)

            loss = self.train_wrapper(raw_wave = wave, return_loss = True)

            # scale so the accumulated gradient is the mean over micro-batches
            self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.transformer.parameters(), self.max_grad_norm)

        self.optim.step()
        self.optim.zero_grad()

        # log

        self.print(f"{steps}: loss: {logs['loss']}")

        # validate every so often

        if self.is_main and not (steps % self.save_results_every):
            self.train_wrapper.eval()

            wave = next(self.valid_dl_iter).to(device)

            with torch.no_grad():
                valid_loss = self.train_wrapper(raw_wave = wave, return_loss = True)

            # `.item()` prints the number rather than the tensor repr,
            # matching how the training loss is reported above
            self.print(f'{steps}: valid loss {valid_loss.item()}')

        # save model every so often

        if self.is_main and not (steps % self.save_model_every):
            state_dict = self.transformer.state_dict()
            model_path = str(self.results_folder / f'fine.transformer.{steps}.pt')
            torch.save(state_dict, model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        self.steps += 1
        return logs

    def train(self, log_fn = noop):
        """Call `train_step` until `num_train_steps` is reached.

        Args:
            log_fn: callable invoked with the logs dict of every step.
        """
        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        self.print('training complete')
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.51', version = '0.0.52', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading