Commit e50305de authored by Phil Wang's avatar Phil Wang
Browse files

move beartype onto methods

parent 2d13bf13
Loading
Loading
Loading
Loading
+8 −6
Original line number Diff line number Diff line
@@ -461,8 +461,8 @@ class Transformer(nn.Module):

# the three hierarchical transformers

@beartype
class SemanticTransformer(nn.Module):
    @beartype
    def __init__(
        self,
        *,
@@ -532,6 +532,7 @@ class SemanticTransformer(nn.Module):
        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    @beartype
    def forward(
        self,
        *,
@@ -580,8 +581,8 @@ class SemanticTransformer(nn.Module):
        tokens = self.transformer(tokens, context = text_embeds, self_attn_mask = self_attn_mask, context_mask = text_mask)
        return self.to_logits(tokens)

@beartype
class CoarseTransformer(nn.Module):
    @beartype
    def __init__(
        self,
        *,
@@ -673,6 +674,7 @@ class CoarseTransformer(nn.Module):
        scaled_coarse_logits = null_coarse_logits + (coarse_logits - null_coarse_logits) * cond_scale
        return scaled_semantic_logits, scaled_coarse_logits

    @beartype
    def forward(
        self,
        *,
@@ -1089,8 +1091,8 @@ class FineTransformer(nn.Module):

# training wrappers

@beartype
class SemanticTransformerWrapper(nn.Module):
    @beartype
    def __init__(
        self,
        *,
@@ -1263,8 +1265,8 @@ class SemanticTransformerWrapper(nn.Module):

        return loss

@beartype
class CoarseTransformerWrapper(nn.Module):
    @beartype
    def __init__(
        self,
        *,
@@ -1484,8 +1486,8 @@ class CoarseTransformerWrapper(nn.Module):
            coarse_loss * num_coarse_logits
        ) / (num_semantic_logits + num_coarse_logits)

@beartype
class FineTransformerWrapper(nn.Module):
    @beartype
    def __init__(
        self,
        *,
@@ -1711,8 +1713,8 @@ class FineTransformerWrapper(nn.Module):

# audio LM

@beartype
class AudioLM(nn.Module):
    @beartype
    def __init__(
        self,
        *,
+1 −1
Original line number Diff line number Diff line
@@ -31,8 +31,8 @@ OptionalIntOrTupleInt = Optional[Union[int, Tuple[Optional[int], ...]]]

# dataset functions

@beartype
class SoundDataset(Dataset):
    @beartype
    def __init__(
        self,
        folder,
+4 −3
Original line number Diff line number Diff line
@@ -114,6 +114,7 @@ def determine_types(data, config):
# main trainer class

class SoundStreamTrainer(nn.Module):
    @beartype
    def __init__(
        self,
        soundstream: SoundStream,
@@ -524,8 +525,8 @@ class SoundStreamTrainer(nn.Module):

# semantic transformer trainer

@beartype
class SemanticTransformerTrainer(nn.Module):
    @beartype
    def __init__(
        self,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
@@ -763,8 +764,8 @@ class SemanticTransformerTrainer(nn.Module):

# fine transformer trainer

@beartype
class CoarseTransformerTrainer(nn.Module):
    @beartype
    def __init__(
        self,
        transformer: CoarseTransformer,
@@ -1013,8 +1014,8 @@ class CoarseTransformerTrainer(nn.Module):

# fine transformer trainer

@beartype
class FineTransformerTrainer(nn.Module):
    @beartype
    def __init__(
        self,
        transformer: FineTransformer,
+1 −1
Original line number Diff line number Diff line
__version__ = '0.26.1'
__version__ = '0.26.2'