Loading audiolm_pytorch/audiolm_pytorch.py +54 −4 Original line number Diff line number Diff line Loading @@ -14,7 +14,18 @@ class SoundStream(nn.Module): def forward(self, x): return x # classes # feedforward def FeedForward(dim, mult = 4): inner_dim = int(dim * mult) return nn.Sequential( nn.LayerNorm(dim), nn.Linear(dim, inner_dim, bias = False), nn.GELU(), nn.Linear(inner_dim, dim, bias = False) ) # attention class Attention(nn.Module): def __init__( Loading @@ -28,14 +39,18 @@ class Attention(nn.Module): self.scale = dim_head ** -0.5 inner_dim = dim_head * heads self.norm = nn.LayerNorm(dim) self.to_q = nn.Linear(dim, inner_dim, bias = False) self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) self.to_out = nn.Linear(inner_dim, dim, bias = False) def forward(self, x): x = self.norm(x) q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim = -1) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d'), (q, k, v)) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v)) q = q * self.scale Loading @@ -48,11 +63,46 @@ class Attention(nn.Module): out = rearrange(out, 'b h n d -> b n (h d)') return self.to_out(out) # transformer class Transformer(nn.Module): def __init__( self, *, dim, depth, **kwargs ): super().__init__() self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append(nn.ModuleList([ Attention(dim = dim, **kwargs), FeedForward(dim = dim) ])) self.norm = nn.LayerNorm(dim) def forward(self, x): for attn, ff in self.layers: x = attn(x) + x x = ff(x) + x return self.norm(x) # audio LM class AudioLM(nn.Module): def __init__(self): def __init__( self, *, dim, depth, **kwargs ): super().__init__() self.transformer = Transformer(dim = dim, depth = depth, **kwargs) def forward(self, x): return x return self.transformer(x) Loading
audiolm_pytorch/audiolm_pytorch.py +54 −4 Original line number Diff line number Diff line Loading @@ -14,7 +14,18 @@ class SoundStream(nn.Module): def forward(self, x): return x # classes # feedforward def FeedForward(dim, mult = 4): inner_dim = int(dim * mult) return nn.Sequential( nn.LayerNorm(dim), nn.Linear(dim, inner_dim, bias = False), nn.GELU(), nn.Linear(inner_dim, dim, bias = False) ) # attention class Attention(nn.Module): def __init__( Loading @@ -28,14 +39,18 @@ class Attention(nn.Module): self.scale = dim_head ** -0.5 inner_dim = dim_head * heads self.norm = nn.LayerNorm(dim) self.to_q = nn.Linear(dim, inner_dim, bias = False) self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) self.to_out = nn.Linear(inner_dim, dim, bias = False) def forward(self, x): x = self.norm(x) q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim = -1) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d'), (q, k, v)) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v)) q = q * self.scale Loading @@ -48,11 +63,46 @@ class Attention(nn.Module): out = rearrange(out, 'b h n d -> b n (h d)') return self.to_out(out) # transformer class Transformer(nn.Module): def __init__( self, *, dim, depth, **kwargs ): super().__init__() self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append(nn.ModuleList([ Attention(dim = dim, **kwargs), FeedForward(dim = dim) ])) self.norm = nn.LayerNorm(dim) def forward(self, x): for attn, ff in self.layers: x = attn(x) + x x = ff(x) + x return self.norm(x) # audio LM class AudioLM(nn.Module): def __init__(self): def __init__( self, *, dim, depth, **kwargs ): super().__init__() self.transformer = Transformer(dim = dim, depth = depth, **kwargs) def forward(self, x): return x return self.transformer(x)