Commit 450af495 authored by Phil Wang's avatar Phil Wang
Browse files

fix coarse cross entropy loss weights

parent 09c79a04
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -43,7 +43,7 @@ def hinge_gen_loss(fake):
    return -fake.mean()

def leaky_relu(p = 0.1):
    return nn.LeakyReLU(0.1)
    return nn.LeakyReLU(p)

def gradient_penalty(images, output, weight = 10):
    batch_size = images.shape[0]
@@ -1112,7 +1112,7 @@ class CoarseTransformerWrapper(nn.Module):
        coarse_logits, semantic_logits = map(lambda t: rearrange(t, 'b n c -> b c n'), (coarse_logits, semantic_logits))

        if self.unique_consecutive:
            num_coarse_logits, num_semantic_logits = coarse_logits.shape[0] * coarse_logits.shape[-1], self_attn_mask.sum()
            num_coarse_logits, num_semantic_logits = coarse_labels.numel(), (semantic_labels != self.pad_id).sum()
        else:
            num_coarse_logits, num_semantic_logits = coarse_logits.shape[-1], semantic_logits.shape[-1]

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.12',
  version = '0.0.14',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',