Commit 6b1f93eb authored by Phil Wang's avatar Phil Wang
Browse files

helper generate functions on trainer that forward to training wrapper

parent ddb40ee3
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -719,6 +719,7 @@ class FineTransformer(nn.Module):

# training wrappers

@typechecked
class SemanticTransformerWrapper(nn.Module):
    def __init__(
        self,
+9 −0
Original line number Diff line number Diff line
@@ -443,6 +443,9 @@ class SemanticTransformerTrainer(nn.Module):
    def print(self, msg):
        self.accelerator.print(msg)

    def generate(self, *args, **kwargs):
        return self.train_wrapper.generate(*args, **kwargs)

    @property
    def device(self):
        return self.accelerator.device
@@ -637,6 +640,9 @@ class CoarseTransformerTrainer(nn.Module):
    def print(self, msg):
        self.accelerator.print(msg)

    def generate(self, *args, **kwargs):
        return self.train_wrapper.generate(*args, **kwargs)

    @property
    def device(self):
        return self.accelerator.device
@@ -834,6 +840,9 @@ class FineTransformerTrainer(nn.Module):
    def print(self, msg):
        self.accelerator.print(msg)

    def generate(self, *args, **kwargs):
        return self.train_wrapper.generate(*args, **kwargs)

    @property
    def device(self):
        return self.accelerator.device
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.53',
  version = '0.0.54',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',