Commit 8db0afee authored by Phil Wang's avatar Phil Wang
Browse files

prepare for musiclm inference and also address...

prepare for musiclm inference and also address https://github.com/lucidrains/audiolm-pytorch/issues/67
parent 6919830b
Loading
Loading
Loading
Loading
+33 −14
Original line number Diff line number Diff line
import math
from functools import partial
from functools import partial, wraps

from beartype.typing import Optional, Union, List
from beartype import beartype
@@ -32,6 +32,14 @@ def exists(val):
def default(val, d):
    return val if exists(val) else d

def maybe(fn):
    @wraps(fn)
    def inner(x, *args, **kwargs):
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)
    return inner

def ceil_div(numer, denom):
    return (numer + denom - 1) // denom

@@ -551,6 +559,7 @@ class CoarseTransformer(nn.Module):
        cond_as_self_attn_prefix = False,
        cond_drop_prob = 0.5,
        grad_shrink_alpha = 0.1,
        project_semantic_logits = True,
        **kwargs
    ):
        super().__init__()
@@ -588,7 +597,7 @@ class CoarseTransformer(nn.Module):
        self.codebook_size = codebook_size
        self.num_coarse_quantizers = num_coarse_quantizers

        self.to_semantic_logits = nn.Linear(dim, num_semantic_tokens + 1)
        self.to_semantic_logits = nn.Linear(dim, num_semantic_tokens + 1) if project_semantic_logits else None
        self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size_with_eos, dim))

    @property
@@ -675,7 +684,7 @@ class CoarseTransformer(nn.Module):

        # semantic logits

        semantic_logits = self.to_semantic_logits(pred_semantic_tokens) if not return_only_coarse_logits else None
        semantic_logits = self.to_semantic_logits(pred_semantic_tokens) if not return_only_coarse_logits and exists(self.to_semantic_logits) else None

        # get coarse logits

@@ -718,6 +727,7 @@ class FineTransformer(nn.Module):
        cond_as_self_attn_prefix = False,
        cond_drop_prob = 0.5,
        grad_shrink_alpha = 0.1,
        project_coarse_logits = True,
        **kwargs
    ):
        super().__init__()
@@ -756,7 +766,7 @@ class FineTransformer(nn.Module):
        self.num_coarse_quantizers = num_coarse_quantizers
        self.num_fine_quantizers = num_fine_quantizers

        self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size_with_eos, dim))
        self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size_with_eos, dim)) if project_coarse_logits else None
        self.fine_logit_weights = nn.Parameter(torch.randn(num_fine_quantizers, codebook_size_with_eos, dim))

    @property
@@ -856,7 +866,7 @@ class FineTransformer(nn.Module):

        coarse_logits = None

        if not return_only_fine_logits:
        if not return_only_fine_logits and exists(self.coarse_logit_weights):
            coarse_logits = einsum('q c d, b n q d -> b n q c', self.coarse_logit_weights, pred_coarse_tokens)

            coarse_logits = rearrange(coarse_logits, 'b n q c -> b (n q) c')
@@ -1218,7 +1228,7 @@ class CoarseTransformerWrapper(nn.Module):
        if not return_loss:
            return semantic_logits, coarse_logits

        coarse_logits, semantic_logits = map(lambda t: rearrange(t, 'b n c -> b c n'), (coarse_logits, semantic_logits))
        coarse_logits, semantic_logits = map(lambda t: maybe(rearrange)(t, 'b n c -> b c n'), (coarse_logits, semantic_logits))

        if self.unique_consecutive:
            num_coarse_logits, num_semantic_logits = coarse_labels.numel(), (semantic_labels != self.pad_id).sum()
@@ -1226,7 +1236,7 @@ class CoarseTransformerWrapper(nn.Module):
            num_coarse_logits, num_semantic_logits = coarse_logits.shape[-1], semantic_logits.shape[-1]

        semantic_loss = 0.
        if self.semantic_cross_entropy_loss_weight > 0:
        if self.semantic_cross_entropy_loss_weight > 0 and exists(semantic_logits):
            semantic_loss = F.cross_entropy(
                semantic_logits,
                semantic_labels,
@@ -1424,12 +1434,16 @@ class FineTransformerWrapper(nn.Module):
        if not return_loss:
            return coarse_logits, fine_logits

        coarse_logits, fine_logits = map(lambda t: rearrange(t, 'b n c -> b c n'), (coarse_logits, fine_logits))
        coarse_logits, fine_logits = map(lambda t: maybe(rearrange)(t, 'b n c -> b c n'), (coarse_logits, fine_logits))

        num_coarse_logits, num_fine_logits = coarse_logits.shape[-1], fine_logits.shape[-1]
        num_fine_logits = fine_logits.shape[-1]

        num_coarse_logits = 0
        coarse_loss = 0.
        if self.coarse_cross_entropy_loss_weight > 0:

        if self.coarse_cross_entropy_loss_weight > 0 and exists(coarse_logits):
            num_coarse_logits = coarse_logits.shape[-1]

            coarse_loss = F.cross_entropy(
                coarse_logits,
                coarse_labels
@@ -1499,25 +1513,30 @@ class AudioLM(nn.Module):
        *,
        batch_size = 1,
        text: Optional[List[str]] = None,
        text_embeds: Optional[torch.Tensor] = None,
        prime_wave = None,
        max_length = 2048,
        return_coarse_generated_wave = False,
        mask_out_generated_fine_tokens = False
    ):
        assert not (self.needs_text and not exists(text)), 'text needs to be passed in if one of the transformer requires conditioning'
        assert not (self.needs_text and (not exists(text) and not exists(text_embeds))), 'text needs to be passed in if one of the transformer requires conditioning'

        if self.needs_text:
            if exists(text):
                text_embeds = self.semantic.embed_text(texts)

        if exists(prime_wave):
            prime_wave = prime_wave.to(self.device)

        semantic_token_ids = self.semantic.generate(
            text = text if self.semantic_has_condition else None,
            text_embeds = text_embeds if self.semantic_has_condition else None,
            batch_size = batch_size,
            prime_wave = prime_wave,
            max_length = max_length
        )

        coarse_token_ids_or_recon_wave = self.coarse.generate(
            text = text if self.coarse_has_condition else None,
            text_embeds = text_embeds if self.coarse_has_condition else None,
            semantic_token_ids = semantic_token_ids,
            reconstruct_wave = return_coarse_generated_wave
        )
@@ -1526,7 +1545,7 @@ class AudioLM(nn.Module):
            return coarse_token_ids_or_recon_wave

        generated_wave = self.fine.generate(
            text = text if self.fine_has_condition else None,
            text_embeds = text_embeds if self.fine_has_condition else None,
            coarse_token_ids = coarse_token_ids_or_recon_wave,
            reconstruct_wave = True,
            mask_out_generated_fine_tokens = mask_out_generated_fine_tokens
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.7.9',
  version = '0.8.1',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',