Loading audiolm_pytorch/audiolm_pytorch.py +27 −12 Original line number Diff line number Diff line Loading @@ -856,7 +856,8 @@ class CoarseTransformerWrapper(nn.Module): soundstream: Optional[SoundStream] = None, wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None, pad_id = -1, unique_consecutive = True unique_consecutive = True, semantic_cross_entropy_loss_weight = 1. ): super().__init__() self.soundstream = soundstream Loading @@ -866,6 +867,8 @@ class CoarseTransformerWrapper(nn.Module): self.unique_consecutive = unique_consecutive self.pad_id = pad_id self.semantic_cross_entropy_loss_weight = semantic_cross_entropy_loss_weight self.num_coarse_quantizers = transformer.num_coarse_quantizers self.eos_id = transformer.coarse_eos_id Loading Loading @@ -1009,6 +1012,8 @@ class CoarseTransformerWrapper(nn.Module): else: num_coarse_logits, num_semantic_logits = coarse_logits.shape[-1], semantic_logits.shape[-1] semantic_loss = 0. if self.semantic_cross_entropy_loss_weight > 0: semantic_loss = F.cross_entropy( semantic_logits, semantic_labels, Loading @@ -1020,7 +1025,10 @@ class CoarseTransformerWrapper(nn.Module): coarse_labels ) return (semantic_loss * num_semantic_logits + coarse_loss * num_coarse_logits) / (num_semantic_logits + num_coarse_logits) return ( semantic_loss * num_semantic_logits * self.semantic_cross_entropy_loss_weight + coarse_loss * num_coarse_logits ) / (num_semantic_logits + num_coarse_logits) @typechecked class FineTransformerWrapper(nn.Module): Loading @@ -1029,6 +1037,7 @@ class FineTransformerWrapper(nn.Module): *, transformer: FineTransformer, soundstream: Optional[SoundStream] = None, coarse_cross_entropy_loss_weight = 1., pad_id = -1 ): super().__init__() Loading @@ -1042,6 +1051,7 @@ class FineTransformerWrapper(nn.Module): assert self.num_coarse_quantizers > 0 self.pad_id = pad_id self.coarse_cross_entropy_loss_weight = coarse_cross_entropy_loss_weight @property def device(self): Loading Loading @@ -1177,6 +1187,8 @@ class FineTransformerWrapper(nn.Module): num_coarse_logits, num_fine_logits = coarse_logits.shape[-1], fine_logits.shape[-1] coarse_loss = 0. if self.coarse_cross_entropy_loss_weight > 0: coarse_loss = F.cross_entropy( coarse_logits, coarse_labels Loading @@ -1187,7 +1199,10 @@ class FineTransformerWrapper(nn.Module): fine_labels ) return (coarse_loss * num_coarse_logits + fine_loss * num_fine_logits) / (num_coarse_logits + num_fine_logits) return ( coarse_loss * num_coarse_logits * self.coarse_cross_entropy_loss_weight + fine_loss * num_fine_logits ) / (num_coarse_logits + num_fine_logits) # audio LM Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.45', version = '0.0.46', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/audiolm_pytorch.py +27 −12 Original line number Diff line number Diff line Loading @@ -856,7 +856,8 @@ class CoarseTransformerWrapper(nn.Module): soundstream: Optional[SoundStream] = None, wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None, pad_id = -1, unique_consecutive = True unique_consecutive = True, semantic_cross_entropy_loss_weight = 1. ): super().__init__() self.soundstream = soundstream Loading @@ -866,6 +867,8 @@ class CoarseTransformerWrapper(nn.Module): self.unique_consecutive = unique_consecutive self.pad_id = pad_id self.semantic_cross_entropy_loss_weight = semantic_cross_entropy_loss_weight self.num_coarse_quantizers = transformer.num_coarse_quantizers self.eos_id = transformer.coarse_eos_id Loading Loading @@ -1009,6 +1012,8 @@ class CoarseTransformerWrapper(nn.Module): else: num_coarse_logits, num_semantic_logits = coarse_logits.shape[-1], semantic_logits.shape[-1] semantic_loss = 0. if self.semantic_cross_entropy_loss_weight > 0: semantic_loss = F.cross_entropy( semantic_logits, semantic_labels, Loading @@ -1020,7 +1025,10 @@ class CoarseTransformerWrapper(nn.Module): coarse_labels ) return (semantic_loss * num_semantic_logits + coarse_loss * num_coarse_logits) / (num_semantic_logits + num_coarse_logits) return ( semantic_loss * num_semantic_logits * self.semantic_cross_entropy_loss_weight + coarse_loss * num_coarse_logits ) / (num_semantic_logits + num_coarse_logits) @typechecked class FineTransformerWrapper(nn.Module): Loading @@ -1029,6 +1037,7 @@ class FineTransformerWrapper(nn.Module): *, transformer: FineTransformer, soundstream: Optional[SoundStream] = None, coarse_cross_entropy_loss_weight = 1., pad_id = -1 ): super().__init__() Loading @@ -1042,6 +1051,7 @@ class FineTransformerWrapper(nn.Module): assert self.num_coarse_quantizers > 0 self.pad_id = pad_id self.coarse_cross_entropy_loss_weight = coarse_cross_entropy_loss_weight @property def device(self): Loading Loading @@ -1177,6 +1187,8 @@ class FineTransformerWrapper(nn.Module): num_coarse_logits, num_fine_logits = coarse_logits.shape[-1], fine_logits.shape[-1] coarse_loss = 0. if self.coarse_cross_entropy_loss_weight > 0: coarse_loss = F.cross_entropy( coarse_logits, coarse_labels Loading @@ -1187,7 +1199,10 @@ class FineTransformerWrapper(nn.Module): fine_labels ) return (coarse_loss * num_coarse_logits + fine_loss * num_fine_logits) / (num_coarse_logits + num_fine_logits) return ( coarse_loss * num_coarse_logits * self.coarse_cross_entropy_loss_weight + fine_loss * num_fine_logits ) / (num_coarse_logits + num_fine_logits) # audio LM Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.45', version = '0.0.46', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading