Loading audiolm_pytorch/__init__.py +2 −0 Original line number Diff line number Diff line Loading @@ -8,3 +8,5 @@ from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec from audiolm_pytorch.hubert_kmeans import HubertWithKmeans from audiolm_pytorch.trainer import SoundStreamTrainer, SemanticTransformerTrainer, FineTransformerTrainer, CoarseTransformerTrainer from audiolm_pytorch.audiolm_pytorch import get_embeds audiolm_pytorch/audiolm_pytorch.py +24 −2 Original line number Diff line number Diff line Loading @@ -111,6 +111,28 @@ def batch_unique_consecutive(t, pad_value = 0.): unique_arr = [torch.unique_consecutive(el) for el in t.unbind(dim = 0)] return pad_sequence(unique_arr, batch_first = True, padding_value = pad_value) # function for getting embeds from nn.Embedding but with padding as some designated value (-1) outside the range of the embed table @beartype def get_embeds( embeddings: nn.Embedding, codes: torch.Tensor, pad_id = -1, return_mask = False, mask_pad_pos_to = 0 ): pad_mask = codes == pad_id codes_without_pad = codes.masked_fill(pad_mask, 0) # just retrieve first code as dummy embeds = embeddings(codes_without_pad) if exists(mask_pad_pos_to): embeds = embeds.masked_fill(rearrange(pad_mask, '... -> ... 1'), mask_pad_pos_to) if return_mask: return embeds, ~pad_mask return embeds # relative positional bias class RelativePositionBias(nn.Module): Loading Loading @@ -782,7 +804,6 @@ class SemanticTransformerWrapper(nn.Module): start_length = ids.shape[-1] sample_semantic_ids = ids.clone() batch_range = rearrange(torch.arange(batch, device = device), 'b -> b 1') last_logit_indices = (ids != self.pad_id).sum(dim = -1).long() # sample from transformer Loading @@ -795,7 +816,8 @@ class SemanticTransformerWrapper(nn.Module): **kwargs ) last_logits = logits[batch_range, last_logit_indices] last_logit_indices_expanded = repeat(last_logit_indices, 'b -> b 1 c', b = batch, c = logits.shape[-1]) last_logits = logits.gather(1, last_logit_indices_expanded) last_logits = rearrange(last_logits, 'b 1 c -> b c') Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.61', version = '0.0.62', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/__init__.py +2 −0 Original line number Diff line number Diff line Loading @@ -8,3 +8,5 @@ from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec from audiolm_pytorch.hubert_kmeans import HubertWithKmeans from audiolm_pytorch.trainer import SoundStreamTrainer, SemanticTransformerTrainer, FineTransformerTrainer, CoarseTransformerTrainer from audiolm_pytorch.audiolm_pytorch import get_embeds
audiolm_pytorch/audiolm_pytorch.py +24 −2 Original line number Diff line number Diff line Loading @@ -111,6 +111,28 @@ def batch_unique_consecutive(t, pad_value = 0.): unique_arr = [torch.unique_consecutive(el) for el in t.unbind(dim = 0)] return pad_sequence(unique_arr, batch_first = True, padding_value = pad_value) # function for getting embeds from nn.Embedding but with padding as some designated value (-1) outside the range of the embed table @beartype def get_embeds( embeddings: nn.Embedding, codes: torch.Tensor, pad_id = -1, return_mask = False, mask_pad_pos_to = 0 ): pad_mask = codes == pad_id codes_without_pad = codes.masked_fill(pad_mask, 0) # just retrieve first code as dummy embeds = embeddings(codes_without_pad) if exists(mask_pad_pos_to): embeds = embeds.masked_fill(rearrange(pad_mask, '... -> ... 1'), mask_pad_pos_to) if return_mask: return embeds, ~pad_mask return embeds # relative positional bias class RelativePositionBias(nn.Module): Loading Loading @@ -782,7 +804,6 @@ class SemanticTransformerWrapper(nn.Module): start_length = ids.shape[-1] sample_semantic_ids = ids.clone() batch_range = rearrange(torch.arange(batch, device = device), 'b -> b 1') last_logit_indices = (ids != self.pad_id).sum(dim = -1).long() # sample from transformer Loading @@ -795,7 +816,8 @@ class SemanticTransformerWrapper(nn.Module): **kwargs ) last_logits = logits[batch_range, last_logit_indices] last_logit_indices_expanded = repeat(last_logit_indices, 'b -> b 1 c', b = batch, c = logits.shape[-1]) last_logits = logits.gather(1, last_logit_indices_expanded) last_logits = rearrange(last_logits, 'b 1 c -> b c') Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.61', version = '0.0.62', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading