Loading audiolm_pytorch/audiolm_pytorch.py +7 −7 Original line number Diff line number Diff line Loading @@ -2,7 +2,7 @@ import math from functools import partial from typing import Optional, Union, List from typeguard import typechecked from beartype import beartype import torch from torch import nn, einsum Loading Loading @@ -309,7 +309,7 @@ class Transformer(nn.Module): # the three hierarchical transformers @typechecked @beartype class SemanticTransformer(nn.Module): def __init__( self, Loading Loading @@ -398,7 +398,7 @@ class SemanticTransformer(nn.Module): tokens = self.transformer(tokens, context = text_embeds, context_mask = text_mask) return self.to_logits(tokens) @typechecked @beartype class CoarseTransformer(nn.Module): def __init__( self, Loading Loading @@ -711,7 +711,7 @@ class FineTransformer(nn.Module): # training wrappers @typechecked @beartype class SemanticTransformerWrapper(nn.Module): def __init__( self, Loading Loading @@ -860,7 +860,7 @@ class SemanticTransformerWrapper(nn.Module): return loss @typechecked @beartype class CoarseTransformerWrapper(nn.Module): def __init__( self, Loading Loading @@ -1046,7 +1046,7 @@ class CoarseTransformerWrapper(nn.Module): coarse_loss * num_coarse_logits ) / (num_semantic_logits + num_coarse_logits) @typechecked @beartype class FineTransformerWrapper(nn.Module): def __init__( self, Loading Loading @@ -1236,7 +1236,7 @@ class FineTransformerWrapper(nn.Module): # audio LM @typechecked @beartype class AudioLM(nn.Module): def __init__( self, Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -19,6 +19,7 @@ setup( ], install_requires=[ 'accelerate', 'beartype', 'einops>=0.6', 'ema-pytorch', 'fairseq', Loading @@ -29,7 +30,6 @@ setup( 'torchaudio', 'transformers', 'tqdm', 'typeguard', 'vector-quantize-pytorch>=0.10.11' ], classifiers=[ Loading Loading
audiolm_pytorch/audiolm_pytorch.py +7 −7 Original line number Diff line number Diff line Loading @@ -2,7 +2,7 @@ import math from functools import partial from typing import Optional, Union, List from typeguard import typechecked from beartype import beartype import torch from torch import nn, einsum Loading Loading @@ -309,7 +309,7 @@ class Transformer(nn.Module): # the three hierarchical transformers @typechecked @beartype class SemanticTransformer(nn.Module): def __init__( self, Loading Loading @@ -398,7 +398,7 @@ class SemanticTransformer(nn.Module): tokens = self.transformer(tokens, context = text_embeds, context_mask = text_mask) return self.to_logits(tokens) @typechecked @beartype class CoarseTransformer(nn.Module): def __init__( self, Loading Loading @@ -711,7 +711,7 @@ class FineTransformer(nn.Module): # training wrappers @typechecked @beartype class SemanticTransformerWrapper(nn.Module): def __init__( self, Loading Loading @@ -860,7 +860,7 @@ class SemanticTransformerWrapper(nn.Module): return loss @typechecked @beartype class CoarseTransformerWrapper(nn.Module): def __init__( self, Loading Loading @@ -1046,7 +1046,7 @@ class CoarseTransformerWrapper(nn.Module): coarse_loss * num_coarse_logits ) / (num_semantic_logits + num_coarse_logits) @typechecked @beartype class FineTransformerWrapper(nn.Module): def __init__( self, Loading Loading @@ -1236,7 +1236,7 @@ class FineTransformerWrapper(nn.Module): # audio LM @typechecked @beartype class AudioLM(nn.Module): def __init__( self, Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -19,6 +19,7 @@ setup( ], install_requires=[ 'accelerate', 'beartype', 'einops>=0.6', 'ema-pytorch', 'fairseq', Loading @@ -29,7 +30,6 @@ setup( 'torchaudio', 'transformers', 'tqdm', 'typeguard', 'vector-quantize-pytorch>=0.10.11' ], classifiers=[ Loading