Commit e5408fbd authored by Phil Wang's avatar Phil Wang
Browse files

correct weighting of cross entropy losses

parent 2b6c5662
Loading
Loading
Loading
Loading
+6 −2
Original line number Diff line number Diff line
@@ -1017,6 +1017,8 @@ class FineTransformerWrapper(nn.Module):

        coarse_logits, fine_logits = map(lambda t: rearrange(t, 'b n c -> b c n'), (coarse_logits, fine_logits))

        num_coarse_logits, num_fine_logits = coarse_logits.shape[-1], fine_logits.shape[-1]

        coarse_loss = F.cross_entropy(
            coarse_logits,
            coarse_labels
@@ -1027,7 +1029,7 @@ class FineTransformerWrapper(nn.Module):
            fine_labels
        )

        return (coarse_loss + fine_loss) * 0.5
        return (coarse_loss * num_coarse_logits + fine_loss * num_fine_logits) / (num_coarse_logits + num_fine_logits)

class CoarseTransformerWrapper(nn.Module):
    def __init__(
@@ -1096,6 +1098,8 @@ class CoarseTransformerWrapper(nn.Module):

        coarse_logits, semantic_logits = map(lambda t: rearrange(t, 'b n c -> b c n'), (coarse_logits, semantic_logits))

        num_coarse_logits, num_semantic_logits = coarse_logits.shape[-1], semantic_logits.shape[-1]

        semantic_loss = F.cross_entropy(
            semantic_logits,
            semantic_labels
@@ -1106,7 +1110,7 @@ class CoarseTransformerWrapper(nn.Module):
            coarse_labels
        )

        return (semantic_loss + coarse_loss) * 0.5
        return (semantic_loss * num_semantic_logits + coarse_loss * num_coarse_logits) / (num_semantic_logits + num_coarse_logits)

# audio LM

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.10',
  version = '0.0.11',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',