Commit 58f13b44 authored by Phil Wang's avatar Phil Wang
Browse files

be able to set dataloader num workers on soundstream trainer

parent 2f4407b4
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -130,6 +130,7 @@ class SoundStreamTrainer(nn.Module):
        ema_update_after_step = 500,
        ema_update_every = 10,
        apply_grad_penalty_every = 4,
        dl_num_workers = 0,
        accelerate_kwargs: dict = dict()
    ):
        super().__init__()
@@ -181,9 +182,9 @@ class SoundStreamTrainer(nn.Module):

        # dataloader

        self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True)
        self.dl = get_dataloader(self.ds, batch_size = batch_size, num_workers = dl_num_workers, shuffle = True)

        self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True)
        self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, num_workers = dl_num_workers, shuffle = True)

        # prepare with accelerator

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.4.6',
  version = '0.4.7',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',