Commit 684da45c authored by Phil Wang's avatar Phil Wang
Browse files

add some preliminary generation code for coarse transformer, primed on semantic tokens

parent a18a0cab
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
@@ -107,6 +107,16 @@ loss = coarse_wrapper(
)

loss.backward()

# after a lot of training

mock_semantic_token_ids = torch.randint(0, wav2vec.codebook_size, (1, 128))

coarse_tokens = coarse_wrapper.generate(
    semantic_token_ids = mock_semantic_token_ids,
    max_time_steps = 512
) # (1, 512, 3) - (batch, time steps, num quantizers)

```

ex. `FineTransformer`
+65 −3
Original line number Diff line number Diff line
import math
from functools import partial

from typing import Optional, Union
from typeguard import typechecked

import torch
from torch import nn, einsum
@@ -305,6 +307,7 @@ class Transformer(nn.Module):

# the three hierarchical transformers

@typechecked
class SemanticTransformer(nn.Module):
    def __init__(
        self,
@@ -518,6 +521,7 @@ class SemanticTransformer(nn.Module):

        return loss

@typechecked
class CoarseTransformer(nn.Module):
    def __init__(
        self,
@@ -589,7 +593,8 @@ class CoarseTransformer(nn.Module):
        self_attn_mask = None,
        text = None,
        text_embeds = None,
        cond_drop_prob = None
        cond_drop_prob = None,
        return_only_coarse_logits = False
    ):
        b, device = semantic_token_ids.shape[0], semantic_token_ids.device

@@ -638,7 +643,7 @@ class CoarseTransformer(nn.Module):

        # semantic logits

        semantic_logits = self.to_semantic_logits(pred_semantic_tokens)
        semantic_logits = self.to_semantic_logits(pred_semantic_tokens) if not return_only_coarse_logits else None

        # get coarse logits

@@ -821,7 +826,7 @@ class FineTransformer(nn.Module):

# training wrappers


@typechecked
class CoarseTransformerWrapper(nn.Module):
    def __init__(
        self,
@@ -841,6 +846,61 @@ class CoarseTransformerWrapper(nn.Module):
        self.pad_id = pad_id

        self.num_coarse_quantizers = transformer.num_coarse_quantizers
        self.eos_id = transformer.coarse_eos_id

    @property
    def device(self):
        return next(self.parameters()).device

    @eval_decorator
    @torch.no_grad()
    def generate(
        self,
        *,
        semantic_token_ids,
        max_time_steps = 512,
        cond_scale = 3.,
        filter_thres = 0.9,
        temperature = 1.,
        reshape_output = True
    ):
        batch, device = semantic_token_ids.shape[0], self.device

        semantic_token_ids = semantic_token_ids.to(device)

        coarse_token_ids = torch.empty((batch, 0), device = device, dtype = torch.long)

        init_coarse_time_step = coarse_token_ids.shape[-1]
        output = coarse_token_ids.clone()

        for time_step in range(init_coarse_time_step, max_time_steps):
            for ind in range(self.num_coarse_quantizers):
                is_last_step = ind == (self.num_coarse_quantizers - 1)

                _, coarse_logits = self.transformer.forward_with_cond_scale(
                    coarse_token_ids = coarse_token_ids,
                    semantic_token_ids = semantic_token_ids,
                    cond_scale = cond_scale,
                    return_only_coarse_logits = True
                )

                last_coarse_logits = coarse_logits[:, -1]

                if not is_last_step:
                    last_coarse_logits[:, -1] = float('-inf') # prevent from eos if not last quantizer step, but move this to masking logic within the transformer at some point, for both training and eval

                filtered_logits = top_k(last_coarse_logits, thres = filter_thres)
                sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

                sampled = rearrange(sampled, 'b -> b 1')
                output = torch.cat((output, sampled), dim = -1)

        output = mask_out_after_eos_id(output, self.eos_id, include_eos = False)

        if reshape_output:
            output = rearrange(output, 'b (n q) -> b n q', q = self.num_coarse_quantizers)

        return output

    def forward(
        self,
@@ -918,6 +978,7 @@ class CoarseTransformerWrapper(nn.Module):

        return (semantic_loss * num_semantic_logits + coarse_loss * num_coarse_logits) / (num_semantic_logits + num_coarse_logits)

@typechecked
class FineTransformerWrapper(nn.Module):
    def __init__(
        self,
@@ -999,6 +1060,7 @@ class FineTransformerWrapper(nn.Module):

# audio LM

@typechecked
class AudioLM(nn.Module):
    def __init__(
        self,
+2 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.39',
  version = '0.0.40',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',
@@ -26,6 +26,7 @@ setup(
    'torch>=1.6',
    'torchaudio',
    'transformers',
    'typeguard',
    'vector-quantize-pytorch>=0.10.10'
  ],
  classifiers=[