Loading audiolm_pytorch/audiolm_pytorch.py +13 −6 Original line number Diff line number Diff line Loading @@ -283,6 +283,7 @@ class Transformer(nn.Module): *, dim, depth, heads, dim_context = None, cross_attend = False, grad_shrink_alpha = 0.1, Loading @@ -293,12 +294,12 @@ class Transformer(nn.Module): self.layers = nn.ModuleList([]) self.rel_pos_bias = RelativePositionBias() self.rel_pos_bias = RelativePositionBias(heads = heads) for _ in range(depth): self.layers.append(nn.ModuleList([ Attention(dim = dim, causal = True, **kwargs), Attention(dim = dim, dim_context = dim_context, num_null_kv = 1, norm_context = True, **kwargs) if cross_attend else None, Attention(dim = dim, heads = heads, causal = True, **kwargs), Attention(dim = dim, heads = heads, dim_context = dim_context, num_null_kv = 1, norm_context = True, **kwargs) if cross_attend else None, FeedForward(dim = dim) ])) Loading Loading @@ -337,6 +338,8 @@ class SemanticTransformer(nn.Module): self, *, dim, depth, heads, num_semantic_tokens, t5_name = DEFAULT_T5_NAME, has_condition = False, Loading @@ -358,7 +361,7 @@ class SemanticTransformer(nn.Module): self.eos_id = num_semantic_tokens self.pad_id = pad_id self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.transformer = Transformer(dim = dim, depth = depth, heads = heads, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.to_logits = nn.Linear(dim, num_semantic_tokens + 1) @property Loading Loading @@ -428,6 +431,8 @@ class CoarseTransformer(nn.Module): codebook_size, num_coarse_quantizers, dim, depth, heads, num_semantic_tokens, t5_name = DEFAULT_T5_NAME, has_condition = False, Loading @@ -450,7 +455,7 @@ class CoarseTransformer(nn.Module): codebook_size_with_eos = codebook_size + 1 self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size_with_eos, dim) self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.transformer = Transformer(dim = dim, depth = depth, heads = heads, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.codebook_size = codebook_size self.num_coarse_quantizers = num_coarse_quantizers Loading Loading @@ -574,6 +579,8 @@ class FineTransformer(nn.Module): num_fine_quantizers, codebook_size, dim, depth, heads, t5_name = DEFAULT_T5_NAME, has_condition = False, cond_drop_prob = 0.5, Loading @@ -597,7 +604,7 @@ class FineTransformer(nn.Module): self.eos_id = codebook_size self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.transformer = Transformer(dim = dim, depth = depth, heads = heads, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.codebook_size = codebook_size self.num_coarse_quantizers = num_coarse_quantizers Loading Loading
audiolm_pytorch/audiolm_pytorch.py +13 −6 Original line number Diff line number Diff line Loading @@ -283,6 +283,7 @@ class Transformer(nn.Module): *, dim, depth, heads, dim_context = None, cross_attend = False, grad_shrink_alpha = 0.1, Loading @@ -293,12 +294,12 @@ class Transformer(nn.Module): self.layers = nn.ModuleList([]) self.rel_pos_bias = RelativePositionBias() self.rel_pos_bias = RelativePositionBias(heads = heads) for _ in range(depth): self.layers.append(nn.ModuleList([ Attention(dim = dim, causal = True, **kwargs), Attention(dim = dim, dim_context = dim_context, num_null_kv = 1, norm_context = True, **kwargs) if cross_attend else None, Attention(dim = dim, heads = heads, causal = True, **kwargs), Attention(dim = dim, heads = heads, dim_context = dim_context, num_null_kv = 1, norm_context = True, **kwargs) if cross_attend else None, FeedForward(dim = dim) ])) Loading Loading @@ -337,6 +338,8 @@ class SemanticTransformer(nn.Module): self, *, dim, depth, heads, num_semantic_tokens, t5_name = DEFAULT_T5_NAME, has_condition = False, Loading @@ -358,7 +361,7 @@ class SemanticTransformer(nn.Module): self.eos_id = num_semantic_tokens self.pad_id = pad_id self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.transformer = Transformer(dim = dim, depth = depth, heads = heads, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.to_logits = nn.Linear(dim, num_semantic_tokens + 1) @property Loading Loading @@ -428,6 +431,8 @@ class CoarseTransformer(nn.Module): codebook_size, num_coarse_quantizers, dim, depth, heads, num_semantic_tokens, t5_name = DEFAULT_T5_NAME, has_condition = False, Loading @@ -450,7 +455,7 @@ class CoarseTransformer(nn.Module): codebook_size_with_eos = codebook_size + 1 self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size_with_eos, dim) self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.transformer = Transformer(dim = dim, depth = depth, heads = heads, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.codebook_size = codebook_size self.num_coarse_quantizers = num_coarse_quantizers Loading Loading @@ -574,6 +579,8 @@ class FineTransformer(nn.Module): num_fine_quantizers, codebook_size, dim, depth, heads, t5_name = DEFAULT_T5_NAME, has_condition = False, cond_drop_prob = 0.5, Loading @@ -597,7 +604,7 @@ class FineTransformer(nn.Module): self.eos_id = codebook_size self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.transformer = Transformer(dim = dim, depth = depth, heads = heads, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.codebook_size = codebook_size self.num_coarse_quantizers = num_coarse_quantizers Loading