Loading audiolm_pytorch/audiolm_pytorch.py +6 −2 Original line number Diff line number Diff line Loading @@ -1463,12 +1463,16 @@ class CoarseTransformerWrapper(nn.Module): coarse_logits, semantic_logits = map(lambda t: maybe(rearrange)(t, 'b n c -> b c n'), (coarse_logits, semantic_logits)) if self.unique_consecutive: num_coarse_logits, num_semantic_logits = coarse_labels.numel(), (semantic_labels != self.pad_id).sum() num_coarse_logits, _num_semantic_logits = coarse_labels.numel(), (semantic_labels != self.pad_id).sum() else: num_coarse_logits, num_semantic_logits = coarse_logits.shape[-1], semantic_logits.shape[-1] num_coarse_logits, _num_semantic_logits = coarse_logits.shape[-1], semantic_logits.shape[-1] semantic_loss = 0. num_semantic_logits = 0 if self.semantic_cross_entropy_loss_weight > 0 and exists(semantic_logits): num_semantic_logits = _num_semantic_logits semantic_loss = F.cross_entropy( semantic_logits, semantic_labels, Loading audiolm_pytorch/version.py +1 −1 Original line number Diff line number Diff line __version__ = '0.30.2' __version__ = '0.30.3' Loading
audiolm_pytorch/audiolm_pytorch.py +6 −2 Original line number Diff line number Diff line Loading @@ -1463,12 +1463,16 @@ class CoarseTransformerWrapper(nn.Module): coarse_logits, semantic_logits = map(lambda t: maybe(rearrange)(t, 'b n c -> b c n'), (coarse_logits, semantic_logits)) if self.unique_consecutive: num_coarse_logits, num_semantic_logits = coarse_labels.numel(), (semantic_labels != self.pad_id).sum() num_coarse_logits, _num_semantic_logits = coarse_labels.numel(), (semantic_labels != self.pad_id).sum() else: num_coarse_logits, num_semantic_logits = coarse_logits.shape[-1], semantic_logits.shape[-1] num_coarse_logits, _num_semantic_logits = coarse_logits.shape[-1], semantic_logits.shape[-1] semantic_loss = 0. num_semantic_logits = 0 if self.semantic_cross_entropy_loss_weight > 0 and exists(semantic_logits): num_semantic_logits = _num_semantic_logits semantic_loss = F.cross_entropy( semantic_logits, semantic_labels, Loading
audiolm_pytorch/version.py +1 −1 Original line number Diff line number Diff line __version__ = '0.30.2' __version__ = '0.30.3'