Loading audiolm_pytorch/audiolm_pytorch.py +26 −14 Original line number Diff line number Diff line Loading @@ -18,6 +18,8 @@ from audiolm_pytorch.hubert_kmeans import HubertWithKmeans from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME from torchaudio.functional import resample # helper functions def exists(val): Loading Loading @@ -295,9 +297,12 @@ class SoundStream(nn.Module): adversarial_loss_weight = 1., feature_loss_weight = 100, quantize_dropout = True, quantize_dropout_cutoff_index = 0 quantize_dropout_cutoff_index = 0, target_sample_khz = 24000 ): super().__init__() self.target_sample_khz = target_sample_khz # for resampling on the fly self.single_channel = input_channels == 1 self.strides = strides Loading Loading @@ -363,8 +368,12 @@ class SoundStream(nn.Module): return_encoded = False, return_discr_loss = False, return_discr_losses_separately = False, return_recons_only = False return_recons_only = False, input_sample_khz = None ): if exists(input_sample_khz): x = resample(x, input_sample_khz, self.target_sample_khz) if x.ndim == 2: x = rearrange(x, 'b n -> b 1 n') Loading Loading @@ -699,7 +708,7 @@ class SemanticTransformer(nn.Module): ids = None, return_loss = False, text = None, text_embed = None, text_embeds = None, cond_drop_prob = None ): device = next(self.parameters()).device Loading @@ -717,17 +726,18 @@ class SemanticTransformer(nn.Module): if self.unique_consecutive: ids = batch_unique_consecutive(ids, pad_value = self.pad_id) has_text = exists(text) or exists(text_embed) has_text = exists(text) or exists(text_embeds) assert not (self.has_condition ^ has_text) if not exists(text_embed): text_mask = None if not exists(text_embeds) and exists(text): with torch.no_grad(): text_embeds = self.embed_text(text, output_device = device) text_mask = torch.any(text_embeds != 0, dim = -1) cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob) if cond_drop_prob > 0: if exists(text_mask) and cond_drop_prob > 0: keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device) text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask Loading Loading @@ -798,22 +808,23 @@ class CoarseTransformer(nn.Module): coarse_token_ids, self_attn_mask = None, text = None, text_embed = None, text_embeds = None, cond_drop_prob = None ): b, device = semantic_token_ids.shape[0], semantic_token_ids.device has_text = exists(text) or exists(text_embed) has_text = exists(text) or exists(text_embeds) assert not (self.has_condition ^ has_text) if not exists(text_embed): text_mask = None if not exists(text_embeds) and exists(text): with torch.no_grad(): text_embeds = self.embed_text(text, output_device = device) text_mask = torch.any(text_embeds != 0, dim = -1) cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob) if cond_drop_prob > 0: if exists(text_mask) and cond_drop_prob > 0: keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device) text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask Loading Loading @@ -906,21 +917,22 @@ class FineTransformer(nn.Module): coarse_token_ids, fine_token_ids, text = None, text_embed = None, text_embeds = None, cond_drop_prob = None ): b, device = coarse_token_ids.shape[0], coarse_token_ids.device has_text = exists(text) or exists(text_embed) has_text = exists(text) or exists(text_embeds) assert not (self.has_condition ^ has_text) if not exists(text_embed): text_mask = None if not exists(text_embeds) and exists(text): with torch.no_grad(): text_embeds = self.embed_text(text, output_device = device) text_mask = torch.any(text_embeds != 0, dim = -1) cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob) if cond_drop_prob > 0: if exists(text_mask) and cond_drop_prob > 0: keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device) text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask Loading audiolm_pytorch/hubert_kmeans.py +18 −2 Original line number Diff line number Diff line Loading @@ -7,13 +7,21 @@ from einops import rearrange, pack, unpack import joblib import fairseq from torchaudio.functional import resample def exists(val): return val is not None class HubertWithKmeans(nn.Module): def __init__( self, checkpoint_path, kmeans_path kmeans_path, target_sample_khz = 50000 ): super().__init__() self.target_sample_khz = target_sample_khz model_path = Path(checkpoint_path) kmeans_path = Path(kmeans_path) Loading @@ -39,9 +47,17 @@ class HubertWithKmeans(nn.Module): return self.kmeans.n_clusters @torch.no_grad() def forward(self, wav_input, flatten = True): def forward( self, wav_input, flatten = True, input_sample_khz = None ): device = wav_input.device if exists(input_sample_khz): wav_input = resample(wav_input, input_sample_khz, self.target_sample_khz) embed = self.model(wav_input, features_only = True) embed, packed_shape = pack([embed['x']], '* d') Loading audiolm_pytorch/vq_wav2vec.py +18 −2 Original line number Diff line number Diff line Loading @@ -6,12 +6,20 @@ from einops import rearrange import fairseq from torchaudio.functional import resample def exists(val): return val is not None class FairseqVQWav2Vec(nn.Module): def __init__( self, checkpoint_path checkpoint_path, target_sample_khz = 24000 ): super().__init__() self.target_sample_khz = target_sample_khz path = Path(checkpoint_path) assert path.exists(), f'path {checkpoint_path} does not exist' Loading @@ -31,7 +39,15 @@ class FairseqVQWav2Vec(nn.Module): return self.model.vector_quantizer.embedding.shape[0] @torch.no_grad() def forward(self, wav_input, flatten = True): def forward( self, wav_input, flatten = True, input_sample_khz = None ): if exists(input_sample_khz): wav_input = resample(wav_input, input_sample_khz, self.target_sample_khz) embed = self.model.feature_extractor(wav_input) _, codebook_indices = self.model.vector_quantizer.forward_idx(embed) Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.20', version = '0.0.21', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/audiolm_pytorch.py +26 −14 Original line number Diff line number Diff line Loading @@ -18,6 +18,8 @@ from audiolm_pytorch.hubert_kmeans import HubertWithKmeans from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME from torchaudio.functional import resample # helper functions def exists(val): Loading Loading @@ -295,9 +297,12 @@ class SoundStream(nn.Module): adversarial_loss_weight = 1., feature_loss_weight = 100, quantize_dropout = True, quantize_dropout_cutoff_index = 0 quantize_dropout_cutoff_index = 0, target_sample_khz = 24000 ): super().__init__() self.target_sample_khz = target_sample_khz # for resampling on the fly self.single_channel = input_channels == 1 self.strides = strides Loading Loading @@ -363,8 +368,12 @@ class SoundStream(nn.Module): return_encoded = False, return_discr_loss = False, return_discr_losses_separately = False, return_recons_only = False return_recons_only = False, input_sample_khz = None ): if exists(input_sample_khz): x = resample(x, input_sample_khz, self.target_sample_khz) if x.ndim == 2: x = rearrange(x, 'b n -> b 1 n') Loading Loading @@ -699,7 +708,7 @@ class SemanticTransformer(nn.Module): ids = None, return_loss = False, text = None, text_embed = None, text_embeds = None, cond_drop_prob = None ): device = next(self.parameters()).device Loading @@ -717,17 +726,18 @@ class SemanticTransformer(nn.Module): if self.unique_consecutive: ids = batch_unique_consecutive(ids, pad_value = self.pad_id) has_text = exists(text) or exists(text_embed) has_text = exists(text) or exists(text_embeds) assert not (self.has_condition ^ has_text) if not exists(text_embed): text_mask = None if not exists(text_embeds) and exists(text): with torch.no_grad(): text_embeds = self.embed_text(text, output_device = device) text_mask = torch.any(text_embeds != 0, dim = -1) cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob) if cond_drop_prob > 0: if exists(text_mask) and cond_drop_prob > 0: keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device) text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask Loading Loading @@ -798,22 +808,23 @@ class CoarseTransformer(nn.Module): coarse_token_ids, self_attn_mask = None, text = None, text_embed = None, text_embeds = None, cond_drop_prob = None ): b, device = semantic_token_ids.shape[0], semantic_token_ids.device has_text = exists(text) or exists(text_embed) has_text = exists(text) or exists(text_embeds) assert not (self.has_condition ^ has_text) if not exists(text_embed): text_mask = None if not exists(text_embeds) and exists(text): with torch.no_grad(): text_embeds = self.embed_text(text, output_device = device) text_mask = torch.any(text_embeds != 0, dim = -1) cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob) if cond_drop_prob > 0: if exists(text_mask) and cond_drop_prob > 0: keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device) text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask Loading Loading @@ -906,21 +917,22 @@ class FineTransformer(nn.Module): coarse_token_ids, fine_token_ids, text = None, text_embed = None, text_embeds = None, cond_drop_prob = None ): b, device = coarse_token_ids.shape[0], coarse_token_ids.device has_text = exists(text) or exists(text_embed) has_text = exists(text) or exists(text_embeds) assert not (self.has_condition ^ has_text) if not exists(text_embed): text_mask = None if not exists(text_embeds) and exists(text): with torch.no_grad(): text_embeds = self.embed_text(text, output_device = device) text_mask = torch.any(text_embeds != 0, dim = -1) cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob) if cond_drop_prob > 0: if exists(text_mask) and cond_drop_prob > 0: keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device) text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask Loading
audiolm_pytorch/hubert_kmeans.py +18 −2 Original line number Diff line number Diff line Loading @@ -7,13 +7,21 @@ from einops import rearrange, pack, unpack import joblib import fairseq from torchaudio.functional import resample def exists(val): return val is not None class HubertWithKmeans(nn.Module): def __init__( self, checkpoint_path, kmeans_path kmeans_path, target_sample_khz = 50000 ): super().__init__() self.target_sample_khz = target_sample_khz model_path = Path(checkpoint_path) kmeans_path = Path(kmeans_path) Loading @@ -39,9 +47,17 @@ class HubertWithKmeans(nn.Module): return self.kmeans.n_clusters @torch.no_grad() def forward(self, wav_input, flatten = True): def forward( self, wav_input, flatten = True, input_sample_khz = None ): device = wav_input.device if exists(input_sample_khz): wav_input = resample(wav_input, input_sample_khz, self.target_sample_khz) embed = self.model(wav_input, features_only = True) embed, packed_shape = pack([embed['x']], '* d') Loading
audiolm_pytorch/vq_wav2vec.py +18 −2 Original line number Diff line number Diff line Loading @@ -6,12 +6,20 @@ from einops import rearrange import fairseq from torchaudio.functional import resample def exists(val): return val is not None class FairseqVQWav2Vec(nn.Module): def __init__( self, checkpoint_path checkpoint_path, target_sample_khz = 24000 ): super().__init__() self.target_sample_khz = target_sample_khz path = Path(checkpoint_path) assert path.exists(), f'path {checkpoint_path} does not exist' Loading @@ -31,7 +39,15 @@ class FairseqVQWav2Vec(nn.Module): return self.model.vector_quantizer.embedding.shape[0] @torch.no_grad() def forward(self, wav_input, flatten = True): def forward( self, wav_input, flatten = True, input_sample_khz = None ): if exists(input_sample_khz): wav_input = resample(wav_input, input_sample_khz, self.target_sample_khz) embed = self.model.feature_extractor(wav_input) _, codebook_indices = self.model.vector_quantizer.forward_idx(embed) Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.20', version = '0.0.21', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading