Commit 8c4d01ad authored by Phil Wang's avatar Phil Wang
Browse files

only during training

parent 127958b9
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -933,7 +933,7 @@ class SemanticTransformerWrapper(nn.Module):
            input_ids = semantic_token_ids[:, :-1]

        self_attn_mask = None
        if self.mask_prob > 0.:
        if self.mask_prob > 0. and self.training:
            self_attn_mask = generate_mask_with_prob(input_ids.shape, self.mask_prob, input_ids.device)

        logits = self.transformer(
@@ -1118,7 +1118,7 @@ class CoarseTransformerWrapper(nn.Module):

        # forgetful causal mask - structured dropout

        if self.mask_prob > 0:
        if self.mask_prob > 0 and self.training:
            self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device)

        # whether to early return the logits
@@ -1320,7 +1320,7 @@ class FineTransformerWrapper(nn.Module):

        # forgetful causal mask - structured dropout

        if self.mask_prob > 0:
        if self.mask_prob > 0 and self.training:
            self_attn_mask &= generate_mask_with_prob(self_attn_mask.shape, self.mask_prob, device = self_attn_mask.device)

        # early return the logits
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.67',
  version = '0.0.68',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',