Commit 277dabf5 authored by Phil Wang's avatar Phil Wang
Browse files

they use three causal attention networks for semantic, coarse, fine. prepare...

they use three causal attention networks for semantic, coarse, fine. prepare to open source soundstream
parent 9b1fef52
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
from audiolm_pytorch.audiolm_pytorch import AudioLM
from audiolm_pytorch.audiolm_pytorch import SoundStream
+68 −4
Original line number Diff line number Diff line
import math

import torch
import torch.nn.functional as F
from torch import nn, einsum
@@ -5,6 +7,11 @@ from einops import rearrange

from vector_quantize_pytorch import VectorQuantize as VQ

# helper functions

def exists(val):
    return val is not None

# sound stream

class SoundStream(nn.Module):
@@ -14,6 +21,51 @@ class SoundStream(nn.Module):
    def forward(self, x):
        return x

# relative positional bias

class RelativePositionBias(nn.Module):
    def __init__(
        self,
        num_buckets = 32,
        max_distance = 128,
        heads = 8
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
        ret = 0

        n = -relative_position
        n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()

        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, i, j, device):

        q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
        k_pos = torch.arange(j, dtype = torch.long, device = device)

        rel_pos = k_pos[None, :] - q_pos[:, None]

        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
        values = self.relative_attention_bias(rp_bucket)

        return rearrange(values, 'i j h -> h i j')

# feedforward

def FeedForward(dim, mult = 4):
@@ -45,7 +97,7 @@ class Attention(nn.Module):
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
    def forward(self, x, attn_bias = None):
        x = self.norm(x)

        q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)
@@ -56,6 +108,9 @@ class Attention(nn.Module):

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        if exists(attn_bias):
            sim = sim + attn_bias

        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
@@ -76,6 +131,8 @@ class Transformer(nn.Module):
        super().__init__()
        self.layers = nn.ModuleList([])

        self.rel_pos_bias = RelativePositionBias()

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, **kwargs),
@@ -85,8 +142,12 @@ class Transformer(nn.Module):
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        n, device = x.shape[1], x.device

        rel_pos_bias = self.rel_pos_bias(n, n, device = device)

        for attn, ff in self.layers:
            x = attn(x) + x
            x = attn(x, attn_bias = rel_pos_bias) + x
            x = ff(x) + x

        return self.norm(x)
@@ -102,7 +163,10 @@ class AudioLM(nn.Module):
        **kwargs
    ):
        super().__init__()
        self.transformer = Transformer(dim = dim, depth = depth, **kwargs)
        self.attend_semantic = Transformer(dim = dim, depth = depth, **kwargs)
        self.attend_coarse = Transformer(dim = dim, depth = depth, **kwargs)
        self.attend_fine = Transformer(dim = dim, depth = depth, **kwargs)

    def forward(self, x):
        return self.transformer(x)
        x = self.attend_semantic(x)
        return x