Loading audiolm_pytorch/audiolm_pytorch.py +8 −6 Original line number Diff line number Diff line Loading @@ -461,8 +461,8 @@ class Transformer(nn.Module): # the three hierarchical transformers @beartype class SemanticTransformer(nn.Module): @beartype def __init__( self, *, Loading Loading @@ -532,6 +532,7 @@ class SemanticTransformer(nn.Module): null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) return null_logits + (logits - null_logits) * cond_scale @beartype def forward( self, *, Loading Loading @@ -580,8 +581,8 @@ class SemanticTransformer(nn.Module): tokens = self.transformer(tokens, context = text_embeds, self_attn_mask = self_attn_mask, context_mask = text_mask) return self.to_logits(tokens) @beartype class CoarseTransformer(nn.Module): @beartype def __init__( self, *, Loading Loading @@ -673,6 +674,7 @@ class CoarseTransformer(nn.Module): scaled_coarse_logits = null_coarse_logits + (coarse_logits - null_coarse_logits) * cond_scale return scaled_semantic_logits, scaled_coarse_logits @beartype def forward( self, *, Loading Loading @@ -1089,8 +1091,8 @@ class FineTransformer(nn.Module): # training wrappers @beartype class SemanticTransformerWrapper(nn.Module): @beartype def __init__( self, *, Loading Loading @@ -1263,8 +1265,8 @@ class SemanticTransformerWrapper(nn.Module): return loss @beartype class CoarseTransformerWrapper(nn.Module): @beartype def __init__( self, *, Loading Loading @@ -1484,8 +1486,8 @@ class CoarseTransformerWrapper(nn.Module): coarse_loss * num_coarse_logits ) / (num_semantic_logits + num_coarse_logits) @beartype class FineTransformerWrapper(nn.Module): @beartype def __init__( self, *, Loading Loading @@ -1711,8 +1713,8 @@ class FineTransformerWrapper(nn.Module): # audio LM @beartype class AudioLM(nn.Module): @beartype def __init__( self, *, Loading audiolm_pytorch/data.py +1 −1 Original line number Diff line number Diff line Loading @@ -31,8 +31,8 @@ OptionalIntOrTupleInt = Optional[Union[int, Tuple[Optional[int], ...]]] # dataset functions @beartype class SoundDataset(Dataset): @beartype def __init__( self, folder, Loading audiolm_pytorch/trainer.py +4 −3 Original line number Diff line number Diff line Loading @@ -114,6 +114,7 @@ def determine_types(data, config): # main trainer class class SoundStreamTrainer(nn.Module): @beartype def __init__( self, soundstream: SoundStream, Loading Loading @@ -524,8 +525,8 @@ class SoundStreamTrainer(nn.Module): # semantic transformer trainer @beartype class SemanticTransformerTrainer(nn.Module): @beartype def __init__( self, wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]], Loading Loading @@ -763,8 +764,8 @@ class SemanticTransformerTrainer(nn.Module): # fine transformer trainer @beartype class CoarseTransformerTrainer(nn.Module): @beartype def __init__( self, transformer: CoarseTransformer, Loading Loading @@ -1013,8 +1014,8 @@ class CoarseTransformerTrainer(nn.Module): # fine transformer trainer @beartype class FineTransformerTrainer(nn.Module): @beartype def __init__( self, transformer: FineTransformer, Loading audiolm_pytorch/version.py +1 −1 Original line number Diff line number Diff line __version__ = '0.26.1' __version__ = '0.26.2' Loading
audiolm_pytorch/audiolm_pytorch.py +8 −6 Original line number Diff line number Diff line Loading @@ -461,8 +461,8 @@ class Transformer(nn.Module): # the three hierarchical transformers @beartype class SemanticTransformer(nn.Module): @beartype def __init__( self, *, Loading Loading @@ -532,6 +532,7 @@ class SemanticTransformer(nn.Module): null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) return null_logits + (logits - null_logits) * cond_scale @beartype def forward( self, *, Loading Loading @@ -580,8 +581,8 @@ class SemanticTransformer(nn.Module): tokens = self.transformer(tokens, context = text_embeds, self_attn_mask = self_attn_mask, context_mask = text_mask) return self.to_logits(tokens) @beartype class CoarseTransformer(nn.Module): @beartype def __init__( self, *, Loading Loading @@ -673,6 +674,7 @@ class CoarseTransformer(nn.Module): scaled_coarse_logits = null_coarse_logits + (coarse_logits - null_coarse_logits) * cond_scale return scaled_semantic_logits, scaled_coarse_logits @beartype def forward( self, *, Loading Loading @@ -1089,8 +1091,8 @@ class FineTransformer(nn.Module): # training wrappers @beartype class SemanticTransformerWrapper(nn.Module): @beartype def __init__( self, *, Loading Loading @@ -1263,8 +1265,8 @@ class SemanticTransformerWrapper(nn.Module): return loss @beartype class CoarseTransformerWrapper(nn.Module): @beartype def __init__( self, *, Loading Loading @@ -1484,8 +1486,8 @@ class CoarseTransformerWrapper(nn.Module): coarse_loss * num_coarse_logits ) / (num_semantic_logits + num_coarse_logits) @beartype class FineTransformerWrapper(nn.Module): @beartype def __init__( self, *, Loading Loading @@ -1711,8 +1713,8 @@ class FineTransformerWrapper(nn.Module): # audio LM @beartype class AudioLM(nn.Module): @beartype def __init__( self, *, Loading
audiolm_pytorch/data.py +1 −1 Original line number Diff line number Diff line Loading @@ -31,8 +31,8 @@ OptionalIntOrTupleInt = Optional[Union[int, Tuple[Optional[int], ...]]] # dataset functions @beartype class SoundDataset(Dataset): @beartype def __init__( self, folder, Loading
audiolm_pytorch/trainer.py +4 −3 Original line number Diff line number Diff line Loading @@ -114,6 +114,7 @@ def determine_types(data, config): # main trainer class class SoundStreamTrainer(nn.Module): @beartype def __init__( self, soundstream: SoundStream, Loading Loading @@ -524,8 +525,8 @@ class SoundStreamTrainer(nn.Module): # semantic transformer trainer @beartype class SemanticTransformerTrainer(nn.Module): @beartype def __init__( self, wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]], Loading Loading @@ -763,8 +764,8 @@ class SemanticTransformerTrainer(nn.Module): # fine transformer trainer @beartype class CoarseTransformerTrainer(nn.Module): @beartype def __init__( self, transformer: CoarseTransformer, Loading Loading @@ -1013,8 +1014,8 @@ class CoarseTransformerTrainer(nn.Module): # fine transformer trainer @beartype class FineTransformerTrainer(nn.Module): @beartype def __init__( self, transformer: FineTransformer, Loading
audiolm_pytorch/version.py +1 −1 Original line number Diff line number Diff line __version__ = '0.26.1' __version__ = '0.26.2'