Commit dfd4b803 authored by Phil Wang's avatar Phil Wang
Browse files
parent e0a4fd94
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -8,3 +8,5 @@ from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
from audiolm_pytorch.hubert_kmeans import HubertWithKmeans

from audiolm_pytorch.trainer import SoundStreamTrainer, SemanticTransformerTrainer, FineTransformerTrainer, CoarseTransformerTrainer

from audiolm_pytorch.audiolm_pytorch import get_embeds
+24 −2
Original line number Diff line number Diff line
@@ -111,6 +111,28 @@ def batch_unique_consecutive(t, pad_value = 0.):
    unique_arr = [torch.unique_consecutive(el) for el in t.unbind(dim = 0)]
    return pad_sequence(unique_arr, batch_first = True, padding_value = pad_value)

# function for getting embeds from nn.Embedding but with padding as some designated value (-1) outside the range of the embed table

@beartype
def get_embeds(
    embeddings: nn.Embedding,
    codes: torch.Tensor,
    pad_id = -1,
    return_mask = False,
    mask_pad_pos_to = 0
):
    """Look up `codes` in `embeddings`, tolerating a pad id that is not a valid index.

    Positions where `codes == pad_id` are looked up as index 0 so the embedding
    call never sees the pad id itself. When `exists(mask_pad_pos_to)` holds
    (presumably "is not None" - confirm against the helper's definition), the
    embedding vectors at those positions are then overwritten with
    `mask_pad_pos_to`; otherwise they keep the embedding of index 0.

    Args:
        embeddings: the embedding table to index into.
        codes: tensor of indices, possibly containing `pad_id`.
        pad_id: value in `codes` that marks padding.
        return_mask: if true, also return a boolean mask of non-pad positions.
        mask_pad_pos_to: fill value written into the embeddings at pad positions.

    Returns:
        The embeddings, or `(embeddings, non_pad_mask)` when `return_mask` is set.
        `non_pad_mask` has the same shape as `codes` and is True where
        `codes != pad_id`.
    """
    is_pad = codes == pad_id

    # swap pad ids for index 0 purely as a placeholder so the lookup is valid
    safe_codes = codes.masked_fill(is_pad, 0)
    out = embeddings(safe_codes)

    if exists(mask_pad_pos_to):
        # add a trailing singleton axis so the mask broadcasts over the embedding dimension
        out = out.masked_fill(is_pad.unsqueeze(-1), mask_pad_pos_to)

    if not return_mask:
        return out

    return out, ~is_pad

# relative positional bias

class RelativePositionBias(nn.Module):
@@ -782,7 +804,6 @@ class SemanticTransformerWrapper(nn.Module):
        start_length = ids.shape[-1]
        sample_semantic_ids = ids.clone()

        batch_range = rearrange(torch.arange(batch, device = device), 'b -> b 1')
        last_logit_indices = (ids != self.pad_id).sum(dim = -1).long()

        # sample from transformer
@@ -795,7 +816,8 @@ class SemanticTransformerWrapper(nn.Module):
                **kwargs
            )

            last_logits = logits[batch_range, last_logit_indices]
            last_logit_indices_expanded = repeat(last_logit_indices, 'b -> b 1 c', b = batch, c = logits.shape[-1])
            last_logits = logits.gather(1, last_logit_indices_expanded)

            last_logits = rearrange(last_logits, 'b 1 c -> b c')

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.61',
  version = '0.0.62',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',