Commit 51cef900 authored by Phil Wang
Browse files

make sure soundstream training runs with accelerate

parent dea0e2fe
Loading
Loading
Loading
Loading
+19 −11
Original line number | Diff line number | Diff line
@@ -39,6 +39,7 @@ from audiolm_pytorch.data import SoundDataset, get_dataloader
from audiolm_pytorch.utils import AudioConditionerBase

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# constants

@@ -137,7 +138,9 @@ class SoundStreamTrainer(nn.Module):
        force_clear_prev_results = None  # set to True | False to skip the prompt
    ):
        super().__init__()
        self.accelerator = Accelerator(**accelerate_kwargs)

        kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
        self.accelerator = Accelerator(kwargs_handlers = [kwargs], **accelerate_kwargs)

        self.soundstream = soundstream
        self.ema_soundstream = EMA(soundstream, beta = ema_beta, update_after_step = ema_update_after_step, update_every = ema_update_every)
@@ -195,14 +198,12 @@ class SoundStreamTrainer(nn.Module):
            self.soundstream,
            self.optim,
            self.discr_optim,
            self.dl,
            self.valid_dl
            self.dl
        ) = self.accelerator.prepare(
            self.soundstream,
            self.optim,
            self.discr_optim,
            self.dl,
            self.valid_dl
            self.dl
        )

        # prepare the multiscale discriminators with accelerator
@@ -224,7 +225,7 @@ class SoundStreamTrainer(nn.Module):

        self.results_folder = Path(results_folder)

        if force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
        if self.is_main and force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)
@@ -377,30 +378,37 @@ class SoundStreamTrainer(nn.Module):

        # update exponential moving averaged generator

        self.accelerator.wait_for_everyone()

        if self.is_main:
            self.ema_soundstream.update()

        # sample results every so often

        self.accelerator.wait_for_everyone()

        if self.is_main and not (steps % self.save_results_every):
            for model, filename in ((self.ema_soundstream.ema_model, f'{steps}.ema'), (self.soundstream, str(steps))):
            for model, filename in ((self.ema_soundstream.ema_model, f'{steps}.ema'), (self.unwrapped_soundstream, str(steps))):
                model.eval()

                wave, = next(self.valid_dl_iter)
                wave = wave.to(device)

                with torch.no_grad():
                    recons = model(wave, return_recons_only = True)

                milestone = steps // self.save_results_every

                for ind, recon in enumerate(recons.unbind(dim = 0)):
                    filename = str(self.results_folder / f'sample_{steps}.flac')
                    torchaudio.save(filename, recon.cpu().detach(), self.soundstream.target_sample_hz)
                    torchaudio.save(filename, recon.cpu().detach(), self.unwrapped_soundstream.target_sample_hz)

            self.print(f'{steps}: saving to {str(self.results_folder)}')

        # save model every so often

        self.accelerator.wait_for_everyone()

        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'soundstream.{steps}.pt')
            self.save(model_path)
@@ -528,7 +536,7 @@ class SemanticTransformerTrainer(nn.Module):

        self.results_folder = Path(results_folder)

        if force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
        if self.is_main and force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)
@@ -763,7 +771,7 @@ class CoarseTransformerTrainer(nn.Module):

        self.results_folder = Path(results_folder)

        if force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
        if self.is_main and force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)
+1 −1
Original line number | Diff line number | Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.12.4',
  version = '0.12.5',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',