Commit f7756f56 authored by Phil Wang's avatar Phil Wang
Browse files

a simple measure for greater transformer training stability

parent 5b24b4f5
Loading
Loading
Loading
Loading
+11 −0
Original line number Diff line number Diff line
@@ -133,3 +133,14 @@ loss.backward()
    url     = {https://twitter.com/rivershavewings}
}
```

```bibtex
@misc{ding2021cogview,
    title   = {CogView: Mastering Text-to-Image Generation via Transformers},
    author  = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
    year    = {2021},
    eprint  = {2105.13290},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```
+16 −3
Original line number Diff line number Diff line
@@ -64,6 +64,11 @@ def gradient_penalty(images, output, weight = 10):
    gradients = rearrange(gradients, 'b ... -> b (...)')
    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()

# attention related utils

def grad_shrink(t, alpha = 0.1):
    return t * alpha + t.detach() * (1 - alpha)

# classifier free guidance functions

def uniform(shape, device):
@@ -632,9 +637,12 @@ class Transformer(nn.Module):
        depth,
        dim_context = None,
        cross_attend = False,
        grad_shrink_alpha = 0.1,
        **kwargs
    ):
        super().__init__()
        self.grad_shrink = partial(grad_shrink, alpha = grad_shrink_alpha)

        self.layers = nn.ModuleList([])

        self.rel_pos_bias = RelativePositionBias()
@@ -657,6 +665,8 @@ class Transformer(nn.Module):
    ):
        n, device = x.shape[1], x.device

        x = self.grad_shrink(x) # from cogview paper, adopted by GLM 130B LLM, decreases likelihood of attention net instability

        rel_pos_bias = self.rel_pos_bias(n, n, device = device)

        for attn, cross_attn, ff in self.layers:
@@ -684,6 +694,7 @@ class SemanticTransformer(nn.Module):
        cond_drop_prob = 0.5,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        unique_consecutive = True,
        grad_shrink_alpha = 0.1,
        pad_id = -1,
        **kwargs
    ):
@@ -701,7 +712,7 @@ class SemanticTransformer(nn.Module):
        self.pad_id = pad_id

        self.wav2vec = wav2vec
        self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, **kwargs)
        self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs)
        self.to_logits = nn.Linear(dim, num_semantic_tokens + 1)

    def forward(
@@ -778,6 +789,7 @@ class CoarseTransformer(nn.Module):
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_drop_prob = 0.5,
        grad_shrink_alpha = 0.1,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        **kwargs
    ):
@@ -796,7 +808,7 @@ class CoarseTransformer(nn.Module):
        self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size_with_eos, dim)

        self.wav2vec = wav2vec
        self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, **kwargs)
        self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs)

        self.codebook_size = codebook_size
        self.num_coarse_quantizers = num_coarse_quantizers
@@ -890,6 +902,7 @@ class FineTransformer(nn.Module):
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_drop_prob = 0.5,
        grad_shrink_alpha = 0.1,
        **kwargs
    ):
        super().__init__()
@@ -906,7 +919,7 @@ class FineTransformer(nn.Module):

        self.eos_id = codebook_size

        self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, **kwargs)
        self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs)

        self.codebook_size = codebook_size
        self.num_coarse_quantizers = num_coarse_quantizers
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.22',
  version = '0.0.23',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',