Loading audiolm_pytorch/audiolm_pytorch.py +10 −9 Original line number Diff line number Diff line Loading @@ -1262,6 +1262,11 @@ class CoarseTransformerWrapper(nn.Module): coarse_token_len = coarse_token_ids.shape[-1] self_attn_mask = F.pad(self_attn_mask, (1, coarse_token_len + 1), value = True) # attend to semantic bos and all coarse tokens # forgetful causal mask - structured dropout if self.mask_prob > 0 and self.training: self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device) semantic_logits, coarse_logits = self.transformer( semantic_token_ids = semantic_token_ids, coarse_token_ids = coarse_token_ids, Loading @@ -1271,10 +1276,6 @@ class CoarseTransformerWrapper(nn.Module): **kwargs ) # forgetful causal mask - structured dropout if self.mask_prob > 0 and self.training: self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device) # whether to early return the logits Loading Loading @@ -1482,6 +1483,11 @@ class FineTransformerWrapper(nn.Module): fine_token_seq_len = fine_token_ids.shape[-1] self_attn_mask = F.pad(self_attn_mask, (1, fine_token_seq_len + 1), value = True) # forgetful causal mask - structured dropout if self.mask_prob > 0 and self.training: self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device) coarse_logits, fine_logits = self.transformer( coarse_token_ids = coarse_token_ids, fine_token_ids = fine_token_ids, Loading @@ -1491,11 +1497,6 @@ class FineTransformerWrapper(nn.Module): **kwargs ) # forgetful causal mask - structured dropout if self.mask_prob > 0 and self.training: self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device) # early return the logits if not return_loss: Loading Loading
audiolm_pytorch/audiolm_pytorch.py +10 −9 Original line number Diff line number Diff line Loading @@ -1262,6 +1262,11 @@ class CoarseTransformerWrapper(nn.Module): coarse_token_len = coarse_token_ids.shape[-1] self_attn_mask = F.pad(self_attn_mask, (1, coarse_token_len + 1), value = True) # attend to semantic bos and all coarse tokens # forgetful causal mask - structured dropout if self.mask_prob > 0 and self.training: self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device) semantic_logits, coarse_logits = self.transformer( semantic_token_ids = semantic_token_ids, coarse_token_ids = coarse_token_ids, Loading @@ -1271,10 +1276,6 @@ class CoarseTransformerWrapper(nn.Module): **kwargs ) # forgetful causal mask - structured dropout if self.mask_prob > 0 and self.training: self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device) # whether to early return the logits Loading Loading @@ -1482,6 +1483,11 @@ class FineTransformerWrapper(nn.Module): fine_token_seq_len = fine_token_ids.shape[-1] self_attn_mask = F.pad(self_attn_mask, (1, fine_token_seq_len + 1), value = True) # forgetful causal mask - structured dropout if self.mask_prob > 0 and self.training: self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device) coarse_logits, fine_logits = self.transformer( coarse_token_ids = coarse_token_ids, fine_token_ids = fine_token_ids, Loading @@ -1491,11 +1497,6 @@ class FineTransformerWrapper(nn.Module): **kwargs ) # forgetful causal mask - structured dropout if self.mask_prob > 0 and self.training: self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device) # early return the logits if not return_loss: Loading