Commit 440f4e84 authored by Phil Wang's avatar Phil Wang
Browse files

implement forgetful causal masking for free improvement

parent 91525995
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
@@ -275,3 +275,13 @@ generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the d
    primaryClass = {cs.CV}
}
```

```bibtex
@article{Liu2022FCMFC,
    title   = {FCM: Forgetful Causal Masking Makes Causal Language Models Better Zero-Shot Learners},
    author  = {Hao Liu and Xinyang Geng and Lisa Lee and Igor Mordatch and Sergey Levine and Sharan Narang and P. Abbeel},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2210.13432}
}
```
+92 −23
Original line number Diff line number Diff line
@@ -49,6 +49,17 @@ def eval_decorator(fn):
        return out
    return inner

# tensor helpers

def generate_mask_with_prob(shape, mask_prob, device):
    """Build a random boolean keep-mask over the last dimension of `shape`.

    For every row along the last dimension, `int(seq * mask_prob)` positions
    (clamped to the range [0, seq - 1]) are chosen at random and set to
    `False`; all remaining positions are `True`. Position 0 of each row is
    never selected for masking, so at least one position per row stays `True`.

    Args:
        shape: target mask shape (tuple or `torch.Size`); the last entry is
            treated as the sequence length. Any number of leading dimensions
            is accepted.
        mask_prob: fraction of each row to mask out. Values that would give a
            negative count are treated as 0; values that would mask the whole
            row are capped at `seq - 1`.
        device: device on which the mask is created.

    Returns:
        A `torch.bool` tensor of shape `shape`, `True` = keep, `False` = masked.
    """
    seq = shape[-1]

    # an empty sequence has nothing to mask (and no position 0 to protect)
    if seq == 0:
        return torch.ones(shape, device = device, dtype = torch.bool)

    rand = torch.randn(shape, device = device)

    # force the first position to the lowest possible score so topk never picks it
    rand[..., 0] = -torch.finfo(rand.dtype).max

    # clamp to [0, seq - 1] so topk always receives a valid k
    num_mask = max(0, min(int(seq * mask_prob), seq - 1))

    # highest random scores along the sequence dimension are the masked positions
    indices = rand.topk(num_mask, dim = -1).indices

    # scatter along the same dimension topk selected from, then invert to get a keep-mask
    masked = torch.zeros(shape, device = device).scatter(-1, indices, 1.).bool()
    return ~masked

# attention related utils

def grad_shrink(t, alpha = 0.1):
@@ -291,8 +302,8 @@ class Transformer(nn.Module):
        heads,
        dim_context = None,
        cross_attend = False,
        attn_dropout = 0.1,
        ff_dropout = 0.1,
        attn_dropout = 0.,
        ff_dropout = 0.,
        grad_shrink_alpha = 0.1,
        **kwargs
    ):
@@ -346,10 +357,10 @@ class SemanticTransformer(nn.Module):
        *,
        dim,
        depth,
        heads,
        num_semantic_tokens,
        attn_dropout = 0.1,
        ff_dropout = 0.1,
        heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_drop_prob = 0.5,
@@ -370,7 +381,8 @@ class SemanticTransformer(nn.Module):
        self.eos_id = num_semantic_tokens
        self.pad_id = pad_id

        self.transformer = Transformer(dim = dim,
        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            heads = heads,
            attn_dropout = attn_dropout,
@@ -378,7 +390,9 @@ class SemanticTransformer(nn.Module):
            dim_context = get_encoded_dim(t5_name),
            cross_attend = has_condition,
            grad_shrink_alpha = grad_shrink_alpha,
                                       **kwargs)
            **kwargs
        )

        self.to_logits = nn.Linear(dim, num_semantic_tokens + 1)

    @property
@@ -406,6 +420,7 @@ class SemanticTransformer(nn.Module):
        return_loss = False,
        text: Optional[List[str]] = None,
        text_embeds = None,
        self_attn_mask = None,
        cond_drop_prob = None,
        unique_consecutive = None
    ):
@@ -437,7 +452,10 @@ class SemanticTransformer(nn.Module):

        tokens = torch.cat((start_tokens, tokens), dim = 1)

        tokens = self.transformer(tokens, context = text_embeds, context_mask = text_mask)
        if exists(self_attn_mask):
            self_attn_mask = F.pad(self_attn_mask, (1, 0), value = True)

        tokens = self.transformer(tokens, context = text_embeds, self_attn_mask = self_attn_mask, context_mask = text_mask)
        return self.to_logits(tokens)

@beartype
@@ -449,8 +467,10 @@ class CoarseTransformer(nn.Module):
        num_coarse_quantizers,
        dim,
        depth,
        heads,
        num_semantic_tokens,
        heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_drop_prob = 0.5,
@@ -472,7 +492,17 @@ class CoarseTransformer(nn.Module):
        codebook_size_with_eos = codebook_size + 1
        self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size_with_eos, dim)

        self.transformer = Transformer(dim = dim, depth = depth, heads = heads, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs)
        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            dim_context = get_encoded_dim(t5_name),
            cross_attend = has_condition,
            grad_shrink_alpha = grad_shrink_alpha,
            **kwargs
        )

        self.codebook_size = codebook_size
        self.num_coarse_quantizers = num_coarse_quantizers
@@ -597,7 +627,9 @@ class FineTransformer(nn.Module):
        codebook_size,
        dim,
        depth,
        heads,
        heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_drop_prob = 0.5,
@@ -621,7 +653,17 @@ class FineTransformer(nn.Module):

        self.eos_id = codebook_size

        self.transformer = Transformer(dim = dim, depth = depth, heads = heads, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs)
        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            dim_context = get_encoded_dim(t5_name),
            cross_attend = has_condition,
            grad_shrink_alpha = grad_shrink_alpha,
            **kwargs
        )

        self.codebook_size = codebook_size
        self.num_coarse_quantizers = num_coarse_quantizers
@@ -765,7 +807,8 @@ class SemanticTransformerWrapper(nn.Module):
        transformer: SemanticTransformer,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        pad_id = -1,
        unique_consecutive = True
        unique_consecutive = True,
        mask_prob = 0.15
    ):
        super().__init__()
        self.wav2vec = wav2vec
@@ -775,6 +818,7 @@ class SemanticTransformerWrapper(nn.Module):
        self.unique_consecutive = unique_consecutive
        self.pad_id = pad_id
        self.eos_id = transformer.eos_id
        self.mask_prob = mask_prob

    @property
    def device(self):
@@ -888,10 +932,15 @@ class SemanticTransformerWrapper(nn.Module):
        if return_loss:
            input_ids = semantic_token_ids[:, :-1]

        self_attn_mask = None
        if self.mask_prob > 0.:
            self_attn_mask = generate_mask_with_prob(input_ids.shape, self.mask_prob, input_ids.device)

        logits = self.transformer(
            ids = input_ids,
            text = text,
            text_embeds = text_embeds,
            self_attn_mask = self_attn_mask,
            **kwargs
        )

@@ -916,7 +965,8 @@ class CoarseTransformerWrapper(nn.Module):
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        pad_id = -1,
        unique_consecutive = True,
        semantic_cross_entropy_loss_weight = 1.
        semantic_cross_entropy_loss_weight = 1.,
        mask_prob = 0.15
    ):
        super().__init__()
        self.soundstream = soundstream
@@ -932,6 +982,8 @@ class CoarseTransformerWrapper(nn.Module):
        self.semantic_eos_id = transformer.semantic_eos_id
        self.coarse_eos_id = transformer.coarse_eos_id

        self.mask_prob = mask_prob

    @property
    def device(self):
        return next(self.parameters()).device
@@ -1064,6 +1116,13 @@ class CoarseTransformerWrapper(nn.Module):
            **kwargs
        )

        # forgetful causal mask - structured dropout

        if self.mask_prob > 0:
            self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device)

        # whether to early return the logits

        if not return_loss:
            return semantic_logits, coarse_logits

@@ -1100,7 +1159,8 @@ class FineTransformerWrapper(nn.Module):
        transformer: FineTransformer,
        soundstream: Optional[SoundStream] = None,
        coarse_cross_entropy_loss_weight = 1.,
        pad_id = -1
        pad_id = -1,
        mask_prob = 0.15
    ):
        super().__init__()
        self.soundstream = soundstream
@@ -1115,6 +1175,8 @@ class FineTransformerWrapper(nn.Module):
        self.pad_id = pad_id
        self.coarse_cross_entropy_loss_weight = coarse_cross_entropy_loss_weight

        self.mask_prob = mask_prob

    @property
    def device(self):
        return next(self.parameters()).device
@@ -1256,6 +1318,13 @@ class FineTransformerWrapper(nn.Module):
            **kwargs
        )

        # forgetful causal mask - structured dropout

        if self.mask_prob > 0:
            self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device)

        # early return the logits

        if not return_loss:
            return coarse_logits, fine_logits

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.66',
  version = '0.0.67',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',