Commit 04bed934 authored by Phil Wang
Browse files

use a base class to enforce what the trainer can accept, for musiclm

parent 88bf4dfb
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
from audiolm_pytorch.audiolm_pytorch import AudioLM
from audiolm_pytorch.soundstream import SoundStream

from audiolm_pytorch.audiolm_pytorch import SemanticBase, CoarseBase, FineBase
from audiolm_pytorch.audiolm_pytorch import SemanticTransformer, CoarseTransformer, FineTransformer
from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransformerWrapper, SemanticTransformerWrapper

+14 −3
Original line number Diff line number Diff line
@@ -424,10 +424,21 @@ class Transformer(nn.Module):

        return self.norm(x)

# bases

class SemanticBase(nn.Module):
    """Marker base class for semantic-stage transformer modules.

    Defines no attributes or methods of its own. It exists so that code
    accepting a semantic transformer can name this type in an annotation
    (or an ``isinstance`` check) rather than one concrete implementation.
    """

class CoarseBase(nn.Module):
    """Marker base class for coarse-stage transformer modules.

    Defines no attributes or methods of its own. It exists so that code
    accepting a coarse transformer can name this type in an annotation
    (or an ``isinstance`` check) rather than one concrete implementation.
    """

class FineBase(nn.Module):
    """Marker base class for fine-stage transformer modules.

    Defines no attributes or methods of its own. It exists so that code
    accepting a fine transformer can name this type in an annotation
    (or an ``isinstance`` check) rather than one concrete implementation.
    """

# the three hierarchical transformers

@beartype
class SemanticTransformer(nn.Module):
class SemanticTransformer(SemanticBase):
    def __init__(
        self,
        *,
@@ -542,7 +553,7 @@ class SemanticTransformer(nn.Module):
        return self.to_logits(tokens)

@beartype
class CoarseTransformer(nn.Module):
class CoarseTransformer(CoarseBase):
    def __init__(
        self,
        *,
@@ -710,7 +721,7 @@ class CoarseTransformer(nn.Module):

        return semantic_logits, coarse_logits

class FineTransformer(nn.Module):
class FineTransformer(FineBase):
    def __init__(
        self,
        *,
+6 −3
Original line number Diff line number Diff line
@@ -25,6 +25,9 @@ from ema_pytorch import EMA
from audiolm_pytorch.soundstream import SoundStream

from audiolm_pytorch.audiolm_pytorch import (
    SemanticBase,
    CoarseBase,
    FineBase,
    SemanticTransformer,
    SemanticTransformerWrapper,
    CoarseTransformer,
@@ -426,7 +429,7 @@ class SemanticTransformerTrainer(nn.Module):
    def __init__(
        self,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
        transformer,
        transformer: SemanticBase,
        *,
        num_train_steps,
        batch_size,
@@ -649,7 +652,7 @@ class SemanticTransformerTrainer(nn.Module):
class CoarseTransformerTrainer(nn.Module):
    def __init__(
        self,
        transformer,
        transformer: CoarseBase,
        soundstream: SoundStream,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
        *,
@@ -884,7 +887,7 @@ class CoarseTransformerTrainer(nn.Module):
class FineTransformerTrainer(nn.Module):
    def __init__(
        self,
        transformer,
        transformer: FineBase,
        soundstream: SoundStream,
        *,
        num_train_steps,
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.8.2',
  version = '0.8.3',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',