Commit d5890a24 authored by Phil Wang's avatar Phil Wang
Browse files

switch to continuous positional bias, for the length extrapolation at inference time

parent 8e3d1979
Loading
Loading
Loading
Loading
+11 −0
Original line number Diff line number Diff line
@@ -378,3 +378,14 @@ sample = trainer.generate(text = ['sound of rain drops on the rooftops'], batch_
    volume  = {abs/2104.05707}
}
```

```bibtex
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```
 No newline at end of file
+23 −33
Original line number Diff line number Diff line
@@ -160,47 +160,37 @@ class LayerNorm(nn.Module):
# relative positional bias

class RelativePositionBias(nn.Module):
    """ from https://arxiv.org/abs/2111.09883 """

    def __init__(
        self,
        num_buckets = 32,
        max_distance = 128,
        heads = 8
        *,
        dim,
        heads,
        layers = 3
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
        ret = 0

        n = -relative_position
        n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()

        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
        self.net = nn.ModuleList([])
        self.net.append(nn.Sequential(nn.Linear(1, dim), nn.SiLU()))

        ret += torch.where(is_small, n, val_if_large)
        return ret
        for _ in range(layers - 1):
            self.net.append(nn.Sequential(nn.Linear(dim, dim), nn.SiLU()))

    def forward(self, i, j, device):
        self.net.append(nn.Linear(dim, heads))

        q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
        k_pos = torch.arange(j, dtype = torch.long, device = device)
    def forward(self, n, device = torch.device('cpu')):
        pos = torch.arange(n, device = device)
        rel_pos = (rearrange(pos, 'i -> i 1') - rearrange(pos, 'j -> 1 j'))
        rel_pos += (n - 1)

        rel_pos = k_pos[None, :] - q_pos[:, None]
        x = torch.arange(-n + 1, n, device = device).float()
        x = rearrange(x, '... -> ... 1')

        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        for layer in self.net:
            x = layer(x)

        return rearrange(values, 'i j h -> h i j')
        x = x[rel_pos]
        return rearrange(x, 'i j h -> h i j')

# feedforward

@@ -381,7 +371,7 @@ class Transformer(nn.Module):

        self.layers = nn.ModuleList([])

        self.rel_pos_bias = RelativePositionBias(heads = heads)
        self.rel_pos_bias = RelativePositionBias(dim = dim // 2, heads = heads)

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
@@ -405,7 +395,7 @@ class Transformer(nn.Module):

        x = self.grad_shrink(x) # from cogview paper, adopted by GLM 130B LLM, decreases likelihood of attention net instability

        rel_pos_bias = self.rel_pos_bias(n, n, device = device)
        rel_pos_bias = self.rel_pos_bias(n, device = device)

        self_attn_kwargs = dict()
        if self.cond_as_self_attn_prefix:
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.6.1',
  version = '0.6.2',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',