Commit dad029e8 authored by Phil Wang's avatar Phil Wang
Browse files

switch to beartype

parent 74c93f5a
Loading
Loading
Loading
Loading
+7 −7
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@ import math
from functools import partial

from typing import Optional, Union, List
from typeguard import typechecked
from beartype import beartype

import torch
from torch import nn, einsum
@@ -309,7 +309,7 @@ class Transformer(nn.Module):

# the three hierarchical transformers

@typechecked
@beartype
class SemanticTransformer(nn.Module):
    def __init__(
        self,
@@ -398,7 +398,7 @@ class SemanticTransformer(nn.Module):
        tokens = self.transformer(tokens, context = text_embeds, context_mask = text_mask)
        return self.to_logits(tokens)

@typechecked
@beartype
class CoarseTransformer(nn.Module):
    def __init__(
        self,
@@ -711,7 +711,7 @@ class FineTransformer(nn.Module):

# training wrappers

@typechecked
@beartype
class SemanticTransformerWrapper(nn.Module):
    def __init__(
        self,
@@ -860,7 +860,7 @@ class SemanticTransformerWrapper(nn.Module):

        return loss

@typechecked
@beartype
class CoarseTransformerWrapper(nn.Module):
    def __init__(
        self,
@@ -1046,7 +1046,7 @@ class CoarseTransformerWrapper(nn.Module):
            coarse_loss * num_coarse_logits
        ) / (num_semantic_logits + num_coarse_logits)

@typechecked
@beartype
class FineTransformerWrapper(nn.Module):
    def __init__(
        self,
@@ -1236,7 +1236,7 @@ class FineTransformerWrapper(nn.Module):

# audio LM

@typechecked
@beartype
class AudioLM(nn.Module):
    def __init__(
        self,
+1 −1
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@ setup(
  ],
  install_requires=[
    'accelerate',
    'beartype',
    'einops>=0.6',
    'ema-pytorch',
    'fairseq',
@@ -29,7 +30,6 @@ setup(
    'torchaudio',
    'transformers',
    'tqdm',
    'typeguard',
    'vector-quantize-pytorch>=0.10.11'
  ],
  classifiers=[