Commit 87e835e3 authored by Phil Wang's avatar Phil Wang
Browse files

move beartype onto methods

parent 7422185e
Loading
Loading
Loading
Loading
+7 −4
Original line number Diff line number Diff line
@@ -503,8 +503,8 @@ class AudioSpectrogramTransformer(nn.Module):

# text transformer

@beartype
class TextTransformer(nn.Module):
    @beartype
    def __init__(
        self,
        dim,
@@ -546,6 +546,7 @@ class TextTransformer(nn.Module):
    def device(self):
        return next(self.parameters()).device

    @beartype
    def forward(
        self,
        x = None,
@@ -648,8 +649,8 @@ class MultiLayerContrastiveLoss(nn.Module):

# main classes

@beartype
class MuLaN(nn.Module):
    @beartype
    def __init__(
        self,
        audio_transformer: AudioSpectrogramTransformer,
@@ -705,6 +706,7 @@ class MuLaN(nn.Module):

        return out, audio_layers

    @beartype
    def get_text_latents(
        self,
        texts = None,
@@ -720,6 +722,7 @@ class MuLaN(nn.Module):

        return out, text_layers

    @beartype
    def forward(
        self,
        wavs,
@@ -766,8 +769,8 @@ class MuLaN(nn.Module):

# music lm

@beartype
class MuLaNEmbedQuantizer(AudioConditionerBase):
    @beartype
    def __init__(
        self,
        mulan: MuLaN,
@@ -851,8 +854,8 @@ class MuLaNEmbedQuantizer(AudioConditionerBase):
        cond_embeddings = cond_embeddings.gather(2, indices)
        return rearrange(cond_embeddings, 'b q 1 d -> b q d')

@beartype
class MusicLM(nn.Module):
    @beartype
    def __init__(
        self,
        audio_lm: AudioLM,
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'musiclm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.2.1',
  version = '0.2.2',
  license='MIT',
  description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
  author = 'Phil Wang',