Commit d40e4119 authored by Phil Wang's avatar Phil Wang
Browse files

offset by multiple of codebook size across quantizers for both coarse and fine

parent 0ac5e4f5
Loading
Loading
Loading
Loading
+22 −1
Original line number Diff line number Diff line
@@ -16,6 +16,9 @@ from vector_quantize_pytorch import ResidualVQ
def exists(val):
    return val is not None

def ceil_div(numer, denom):
    return (numer + denom - 1) // denom

def round_down_nearest_multiple(val, mult):
    return (val // mult) * mult

@@ -601,6 +604,7 @@ class CoarseTransformer(nn.Module):

        self.transformer = Transformer(dim = dim, **kwargs)

        self.codebook_size = codebook_size
        self.num_coarse_quantizers = num_coarse_quantizers

        self.to_semantic_logits = nn.Linear(dim, num_semantic_tokens)
@@ -611,10 +615,15 @@ class CoarseTransformer(nn.Module):
        semantic_token_ids,
        coarse_token_ids,
    ):
        b = semantic_token_ids.shape[0]
        b, device = semantic_token_ids.shape[0], semantic_token_ids.device

        coarse_token_ids, semantic_token_ids = map(lambda t: rearrange(t, 'b ... -> b (...)'), (coarse_token_ids, semantic_token_ids))

        offsets = self.codebook_size * torch.arange(self.num_coarse_quantizers, device = device)
        offsets = repeat(offsets, 'q -> 1 (n q)', n = ceil_div(coarse_token_ids.shape[-1], self.num_coarse_quantizers))
        offsets = offsets[:, :coarse_token_ids.shape[-1]]
        coarse_token_ids = coarse_token_ids + offsets

        semantic_tokens = self.semantic_embedding(semantic_token_ids)
        coarse_tokens = self.coarse_embedding(coarse_token_ids)

@@ -687,10 +696,22 @@ class FineTransformer(nn.Module):
        coarse_token_ids,
        fine_token_ids
    ):
        device = coarse_token_ids.device

        coarse_token_ids, fine_token_ids = map(lambda t: rearrange(t, 'b ... -> b (...)'), (coarse_token_ids, fine_token_ids))

        b, n = coarse_token_ids.shape

        coarse_offsets = self.codebook_size * torch.arange(self.num_coarse_quantizers, device = device)
        coarse_offsets = repeat(coarse_offsets, 'q -> 1 (n q)', n = ceil_div(coarse_token_ids.shape[-1], self.num_coarse_quantizers))
        coarse_offsets = coarse_offsets[:, :coarse_token_ids.shape[-1]]
        coarse_token_ids = coarse_token_ids + coarse_offsets

        fine_offsets = self.codebook_size * torch.arange(self.num_fine_quantizers, device = device)
        fine_offsets = repeat(fine_offsets, 'q -> 1 (n q)', n = ceil_div(fine_token_ids.shape[-1], self.num_fine_quantizers))
        fine_offsets = fine_offsets[:, :fine_token_ids.shape[-1]]
        fine_token_ids = fine_token_ids + fine_offsets

        coarse_tokens = self.coarse_embedding(coarse_token_ids)
        fine_tokens = self.fine_embedding(fine_token_ids)

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.3',
  version = '0.0.4',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',