Commit 84bc71d7 authored by Phil Wang's avatar Phil Wang
Browse files

allow for turning off cross entropy loss to preceding tokens (semantic,...

allow for turning off cross entropy loss to preceding tokens (semantic, coarse) in coarse and fine transformers
parent b0fbb485
Loading
Loading
Loading
Loading
+27 −12
Original line number Diff line number Diff line
@@ -856,7 +856,8 @@ class CoarseTransformerWrapper(nn.Module):
        soundstream: Optional[SoundStream]  = None,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        pad_id = -1,
        unique_consecutive = True
        unique_consecutive = True,
        semantic_cross_entropy_loss_weight = 1.
    ):
        super().__init__()
        self.soundstream = soundstream
@@ -866,6 +867,8 @@ class CoarseTransformerWrapper(nn.Module):
        self.unique_consecutive = unique_consecutive
        self.pad_id = pad_id

        self.semantic_cross_entropy_loss_weight = semantic_cross_entropy_loss_weight

        self.num_coarse_quantizers = transformer.num_coarse_quantizers
        self.eos_id = transformer.coarse_eos_id

@@ -1009,6 +1012,8 @@ class CoarseTransformerWrapper(nn.Module):
        else:
            num_coarse_logits, num_semantic_logits = coarse_logits.shape[-1], semantic_logits.shape[-1]

        semantic_loss = 0.
        if self.semantic_cross_entropy_loss_weight > 0:
            semantic_loss = F.cross_entropy(
                semantic_logits,
                semantic_labels,
@@ -1020,7 +1025,10 @@ class CoarseTransformerWrapper(nn.Module):
            coarse_labels
        )

        return (semantic_loss * num_semantic_logits + coarse_loss * num_coarse_logits) / (num_semantic_logits + num_coarse_logits)
        return (
            semantic_loss * num_semantic_logits * self.semantic_cross_entropy_loss_weight +
            coarse_loss * num_coarse_logits
        ) / (num_semantic_logits + num_coarse_logits)

@typechecked
class FineTransformerWrapper(nn.Module):
@@ -1029,6 +1037,7 @@ class FineTransformerWrapper(nn.Module):
        *,
        transformer: FineTransformer,
        soundstream: Optional[SoundStream] = None,
        coarse_cross_entropy_loss_weight = 1.,
        pad_id = -1
    ):
        super().__init__()
@@ -1042,6 +1051,7 @@ class FineTransformerWrapper(nn.Module):
        assert self.num_coarse_quantizers > 0

        self.pad_id = pad_id
        self.coarse_cross_entropy_loss_weight = coarse_cross_entropy_loss_weight

    @property
    def device(self):
@@ -1177,6 +1187,8 @@ class FineTransformerWrapper(nn.Module):

        num_coarse_logits, num_fine_logits = coarse_logits.shape[-1], fine_logits.shape[-1]

        coarse_loss = 0.
        if self.coarse_cross_entropy_loss_weight > 0:
            coarse_loss = F.cross_entropy(
                coarse_logits,
                coarse_labels
@@ -1187,7 +1199,10 @@ class FineTransformerWrapper(nn.Module):
            fine_labels
        )

        return (coarse_loss * num_coarse_logits + fine_loss * num_fine_logits) / (num_coarse_logits + num_fine_logits)
        return (
            coarse_loss * num_coarse_logits * self.coarse_cross_entropy_loss_weight +
            fine_loss * num_fine_logits
        ) / (num_coarse_logits + num_fine_logits)

# audio LM

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.45',
  version = '0.0.46',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',