Commit 31edfade authored by Phil Wang's avatar Phil Wang
Browse files

first pass at fine transformer trainer

parent 28e12b7e
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -7,4 +7,4 @@ from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransf
from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
from audiolm_pytorch.hubert_kmeans import HubertWithKmeans

from audiolm_pytorch.trainer import SoundStreamTrainer, SemanticTransformerTrainer
from audiolm_pytorch.trainer import SoundStreamTrainer, SemanticTransformerTrainer, FineTransformerTrainer
+190 −0
Original line number Diff line number Diff line
@@ -517,3 +517,193 @@ class SemanticTransformerTrainer(nn.Module):
            log_fn(logs)

        self.print('training complete')

# fine transformer trainer

@typechecked
class FineTransformerTrainer(nn.Module):
    """Training harness for a ``FineTransformer`` on raw audio.

    Raw waveforms are read from ``folder`` through ``SoundDataset`` and fed to a
    ``FineTransformerWrapper`` built from the given ``soundstream`` and
    ``transformer``; the wrapper is called with ``return_loss = True`` and the
    resulting loss is backpropagated through the ``Accelerator``.

    Only ``transformer.parameters()`` are handed to the optimizer, so the
    ``soundstream`` weights are not updated by this trainer.

    Args:
        transformer: the fine transformer being trained.
        soundstream: codec passed to the wrapper; its ``target_sample_hz`` and
            ``seq_len_multiple_of`` attributes configure the dataset.
        num_train_steps: ``train()`` loops until the step counter reaches this.
        batch_size: batch size for both the training and validation loaders.
        data_max_length: forwarded to ``SoundDataset`` as ``max_length``.
        folder: dataset location, forwarded to ``SoundDataset``.
        lr: learning rate forwarded to ``get_optimizer``.
        grad_accum_every: number of micro-batches accumulated per optimizer step.
        wd: weight decay forwarded to ``get_optimizer``.
        max_grad_norm: gradient-norm clipping threshold; clipping is skipped
            when ``exists(max_grad_norm)`` is falsy (presumably ``None``).
        valid_frac: fraction of the dataset split off for validation. When it
            is not greater than 0 the training set doubles as the validation set.
        random_split_seed: seed of the generator used for the random split.
        save_results_every: a validation loss is computed and printed every
            this many steps (main process only).
        save_model_every: a transformer checkpoint is written every this many
            steps (main process only).
        results_folder: directory receiving checkpoints; created if missing.
        accelerate_kwargs: extra keyword arguments for ``Accelerator``. The
            default dict is only unpacked, never mutated.
    """

    def __init__(
        self,
        transformer: FineTransformer,
        soundstream: SoundStream,
        *,
        num_train_steps,
        batch_size,
        data_max_length = None,
        folder,
        lr = 3e-4,
        grad_accum_every = 1,
        wd = 0.,
        max_grad_norm = 0.5,
        valid_frac = 0.05,
        random_split_seed = 42,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        accelerate_kwargs: dict = dict()
    ):
        super().__init__()
        self.accelerator = Accelerator(**accelerate_kwargs)

        self.transformer = transformer
        self.soundstream = soundstream

        # NOTE(review): the wrapper captures the transformer object as it is
        # *before* `accelerator.prepare` below. If `prepare` returns a wrapping
        # module (e.g. under distributed training), forward passes made through
        # this wrapper would not go through that wrapping module - confirm
        # this is intended.
        self.train_wrapper = FineTransformerWrapper(
            soundstream = soundstream,
            transformer = transformer
        )

        # step counter kept as a buffer so it is part of the module state
        self.register_buffer('steps', torch.Tensor([0]))

        self.num_train_steps = num_train_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        # optimizers

        self.optim = get_optimizer(transformer.parameters(), lr = lr, wd = wd)

        # max grad norm

        self.max_grad_norm = max_grad_norm

        # create dataset

        self.ds = SoundDataset(
            folder,
            max_length = data_max_length,
            target_sample_hz = soundstream.target_sample_hz,
            seq_len_multiple_of = soundstream.seq_len_multiple_of
        )

        # split for validation

        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        # dataloader

        self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True)

        self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True)

        # prepare with accelerator

        (
            self.transformer,
            self.optim,
            self.dl,
            self.valid_dl
        ) = self.accelerator.prepare(
            self.transformer,
            self.optim,
            self.dl,
            self.valid_dl
        )

        # dataloader iterators

        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        self.results_folder = Path(results_folder)

        # offer to wipe a non-empty results folder before starting a new run
        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)

        self.train_wrapper.to(self.device)

    def print(self, msg):
        """Print ``msg`` through the accelerator."""
        self.accelerator.print(msg)

    @property
    def device(self):
        """Device selected by the accelerator."""
        return self.accelerator.device

    @property
    def is_distributed(self):
        """False only when the distributed type is ``NO`` and there is a single process."""
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    @property
    def is_main(self):
        """Whether this is the accelerator's main process."""
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        """Whether this is the accelerator's local main process."""
        return self.accelerator.is_local_main_process

    def train_step(self):
        """Run one optimizer step and the periodic validation / checkpointing.

        Returns:
            The ``logs`` dict filled through ``accum_log``; it carries a
            ``'loss'`` entry built from each micro-batch loss divided by
            ``grad_accum_every``.
        """
        device = self.device

        steps = int(self.steps.item())

        self.transformer.train()

        # logs

        logs = {}

        # accumulate gradients of the fine transformer over the micro-batches

        for _ in range(self.grad_accum_every):
            wave = next(self.dl_iter).to(device)
            loss = self.train_wrapper(raw_wave = wave, return_loss = True)

            # scale so that the accumulated gradient matches a single large batch
            self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.transformer.parameters(), self.max_grad_norm)

        self.optim.step()
        self.optim.zero_grad()

        # log

        self.print(f"{steps}: loss: {logs['loss']}")

        # sample results every so often

        if self.is_main and not (steps % self.save_results_every):
            self.train_wrapper.eval()

            wave = next(self.valid_dl_iter).to(device)

            with torch.no_grad():
                valid_loss = self.train_wrapper(raw_wave = wave, return_loss = True)

            # report a plain number, consistent with the training loss above
            self.print(f'{steps}: valid loss {valid_loss.item()}')

        # save model every so often

        if self.is_main and not (steps % self.save_model_every):
            # unwrap first so the saved keys are those of the bare transformer,
            # whatever wrapping `accelerator.prepare` applied
            state_dict = self.accelerator.unwrap_model(self.transformer).state_dict()
            model_path = str(self.results_folder / f'fine.transformer.{steps}.pt')
            torch.save(state_dict, model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        self.steps += 1
        return logs

    def train(self, log_fn = noop):
        """Call ``train_step`` until ``num_train_steps`` is reached.

        Args:
            log_fn: callable invoked with the logs dict returned by each step.
        """

        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        self.print('training complete')
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.51',
  version = '0.0.52',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',