Loading audiolm_pytorch/audiolm_pytorch.py +25 −8 Original line number Diff line number Diff line Loading @@ -185,12 +185,13 @@ class GEGLU(nn.Module): x, gate = x.chunk(2, dim = -1) return F.gelu(gate) * x def FeedForward(dim, mult = 4): def FeedForward(dim, mult = 4, dropout = 0.1): inner_dim = int(dim * 2 * mult / 3) return nn.Sequential( nn.LayerNorm(dim), nn.Linear(dim, inner_dim * 2, bias = False), GEGLU(), nn.Dropout(dropout), nn.Linear(inner_dim, dim, bias = False) ) Loading @@ -205,7 +206,8 @@ class Attention(nn.Module): dim_context = None, heads = 8, norm_context = False, num_null_kv = 0 num_null_kv = 0, dropout = 0.1 ): super().__init__() self.heads = heads Loading @@ -223,7 +225,10 @@ class Attention(nn.Module): self.to_q = nn.Linear(dim, inner_dim, bias = False) self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False) self.to_out = nn.Linear(inner_dim, dim, bias = False) self.to_out = nn.Sequential( nn.Linear(inner_dim, dim, bias = False), nn.Dropout(dropout) ) def forward( self, Loading Loading @@ -286,6 +291,8 @@ class Transformer(nn.Module): heads, dim_context = None, cross_attend = False, attn_dropout = 0.1, ff_dropout = 0.1, grad_shrink_alpha = 0.1, **kwargs ): Loading @@ -298,9 +305,9 @@ class Transformer(nn.Module): for _ in range(depth): self.layers.append(nn.ModuleList([ Attention(dim = dim, heads = heads, causal = True, **kwargs), Attention(dim = dim, heads = heads, dim_context = dim_context, num_null_kv = 1, norm_context = True, **kwargs) if cross_attend else None, FeedForward(dim = dim) Attention(dim = dim, heads = heads, dropout = attn_dropout, causal = True, **kwargs), Attention(dim = dim, heads = heads, dropout = attn_dropout, dim_context = dim_context, num_null_kv = 1, norm_context = True, **kwargs) if cross_attend else None, FeedForward(dim = dim, dropout = ff_dropout) ])) self.norm = nn.LayerNorm(dim) Loading Loading @@ -341,6 +348,8 @@ class SemanticTransformer(nn.Module): depth, heads, num_semantic_tokens, attn_dropout = 0.1, ff_dropout = 0.1, t5_name = DEFAULT_T5_NAME, has_condition = False, cond_drop_prob = 0.5, Loading @@ -361,7 +370,15 @@ class SemanticTransformer(nn.Module): self.eos_id = num_semantic_tokens self.pad_id = pad_id self.transformer = Transformer(dim = dim, depth = depth, heads = heads, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.transformer = Transformer(dim = dim, depth = depth, heads = heads, attn_dropout = attn_dropout, ff_dropout = ff_dropout, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.to_logits = nn.Linear(dim, num_semantic_tokens + 1) @property Loading Loading @@ -414,7 +431,7 @@ class SemanticTransformer(nn.Module): if return_loss: labels, ids = ids.clone(), ids[:, :-1] tokens = self.semantic_embedding(ids) tokens = get_embeds(self.semantic_embedding, ids) start_tokens = repeat(self.start_token, 'd -> b 1 d', b = ids.shape[0]) Loading Loading
audiolm_pytorch/audiolm_pytorch.py +25 −8 Original line number Diff line number Diff line Loading @@ -185,12 +185,13 @@ class GEGLU(nn.Module): x, gate = x.chunk(2, dim = -1) return F.gelu(gate) * x def FeedForward(dim, mult = 4): def FeedForward(dim, mult = 4, dropout = 0.1): inner_dim = int(dim * 2 * mult / 3) return nn.Sequential( nn.LayerNorm(dim), nn.Linear(dim, inner_dim * 2, bias = False), GEGLU(), nn.Dropout(dropout), nn.Linear(inner_dim, dim, bias = False) ) Loading @@ -205,7 +206,8 @@ class Attention(nn.Module): dim_context = None, heads = 8, norm_context = False, num_null_kv = 0 num_null_kv = 0, dropout = 0.1 ): super().__init__() self.heads = heads Loading @@ -223,7 +225,10 @@ class Attention(nn.Module): self.to_q = nn.Linear(dim, inner_dim, bias = False) self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False) self.to_out = nn.Linear(inner_dim, dim, bias = False) self.to_out = nn.Sequential( nn.Linear(inner_dim, dim, bias = False), nn.Dropout(dropout) ) def forward( self, Loading Loading @@ -286,6 +291,8 @@ class Transformer(nn.Module): heads, dim_context = None, cross_attend = False, attn_dropout = 0.1, ff_dropout = 0.1, grad_shrink_alpha = 0.1, **kwargs ): Loading @@ -298,9 +305,9 @@ class Transformer(nn.Module): for _ in range(depth): self.layers.append(nn.ModuleList([ Attention(dim = dim, heads = heads, causal = True, **kwargs), Attention(dim = dim, heads = heads, dim_context = dim_context, num_null_kv = 1, norm_context = True, **kwargs) if cross_attend else None, FeedForward(dim = dim) Attention(dim = dim, heads = heads, dropout = attn_dropout, causal = True, **kwargs), Attention(dim = dim, heads = heads, dropout = attn_dropout, dim_context = dim_context, num_null_kv = 1, norm_context = True, **kwargs) if cross_attend else None, FeedForward(dim = dim, dropout = ff_dropout) ])) self.norm = nn.LayerNorm(dim) Loading Loading @@ -341,6 +348,8 @@ class SemanticTransformer(nn.Module): depth, heads, num_semantic_tokens, attn_dropout = 0.1, ff_dropout = 0.1, t5_name = DEFAULT_T5_NAME, has_condition = False, cond_drop_prob = 0.5, Loading @@ -361,7 +370,15 @@ class SemanticTransformer(nn.Module): self.eos_id = num_semantic_tokens self.pad_id = pad_id self.transformer = Transformer(dim = dim, depth = depth, heads = heads, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.transformer = Transformer(dim = dim, depth = depth, heads = heads, attn_dropout = attn_dropout, ff_dropout = ff_dropout, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs) self.to_logits = nn.Linear(dim, num_semantic_tokens + 1) @property Loading Loading @@ -414,7 +431,7 @@ class SemanticTransformer(nn.Module): if return_loss: labels, ids = ids.clone(), ids[:, :-1] tokens = self.semantic_embedding(ids) tokens = get_embeds(self.semantic_embedding, ids) start_tokens = repeat(self.start_token, 'd -> b 1 d', b = ids.shape[0]) Loading