# feedforward

class GEGLU(nn.Module):
    """Gated activation: GELU of one half of the features scales the other half."""

    def forward(self, x):
        # the last dimension is split into two chunks; the second chunk is the gate
        value, gate = x.chunk(2, dim = -1)
        gate = F.gelu(gate)
        return gate * value

def FeedForward(dim, mult = 4, dropout = 0.):
    """
    Build the feedforward block as an nn.Sequential.

    dim     - width of the input and output features
    mult    - expansion factor; the hidden width is int(dim * mult * 2 / 3)
    dropout - probability passed to the nn.Dropout placed after the GEGLU
    """
    # the 2 / 3 factor presumably offsets the doubled input projection - TODO confirm intent
    inner_dim = int(dim * mult * 2 / 3)

    stages = [
        LayerNorm(dim),
        # project to twice the hidden width so that GEGLU can split value from gate
        nn.Linear(dim, inner_dim * 2, bias = False),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(inner_dim, dim, bias = False),
    ]

    return nn.Sequential(*stages)
# transformer

class Transformer(nn.Module):
    """
    Stack of `depth` layers, each an Attention module followed by a FeedForward
    module, both applied with a residual connection.
    """

    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.
    ):
        super().__init__()

        blocks = []

        for _ in range(depth):
            # attention is built before the feedforward, once per layer
            attention = Attention(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                dropout = attn_dropout
            )

            feedforward = FeedForward(
                dim = dim,
                mult = ff_mult,
                dropout = ff_dropout
            )

            blocks.append(nn.ModuleList([attention, feedforward]))

        self.layers = nn.ModuleList(blocks)

    def forward(self, x, mask = None):
        """Run x through every layer; mask is forwarded only to the attention modules."""
        for attention, feedforward in self.layers:
            attended = attention(x, mask = mask)
            x = attended + x

            transformed = feedforward(x)
            x = transformed + x

        return x