Commit e98e856b authored by zhvng's avatar zhvng
Browse files

bugfix - forgetful mask was not used

parent 41861d49
Loading
Loading
Loading
Loading
+10 −9
Original line number Diff line number Diff line
@@ -1262,6 +1262,11 @@ class CoarseTransformerWrapper(nn.Module):
        coarse_token_len = coarse_token_ids.shape[-1]
        self_attn_mask = F.pad(self_attn_mask, (1, coarse_token_len + 1), value = True) # attend to semantic bos and all coarse tokens

        # forgetful causal mask - structured dropout

        if self.mask_prob > 0 and self.training:
            self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device)

        semantic_logits, coarse_logits = self.transformer(
            semantic_token_ids = semantic_token_ids,
            coarse_token_ids = coarse_token_ids,
@@ -1271,10 +1276,6 @@ class CoarseTransformerWrapper(nn.Module):
            **kwargs
        )

        # forgetful causal mask - structured dropout

        if self.mask_prob > 0 and self.training:
            self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device)

        # whether to early return the logits

@@ -1482,6 +1483,11 @@ class FineTransformerWrapper(nn.Module):
        fine_token_seq_len = fine_token_ids.shape[-1]
        self_attn_mask = F.pad(self_attn_mask, (1, fine_token_seq_len + 1), value = True)

        # forgetful causal mask - structured dropout

        if self.mask_prob > 0 and self.training:
            self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device)

        coarse_logits, fine_logits = self.transformer(
            coarse_token_ids = coarse_token_ids,
            fine_token_ids = fine_token_ids,
@@ -1491,11 +1497,6 @@ class FineTransformerWrapper(nn.Module):
            **kwargs
        )

        # forgetful causal mask - structured dropout

        if self.mask_prob > 0 and self.training:
            self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device)

        # early return the logits

        if not return_loss: