Loading audiolm_pytorch/audiolm_pytorch.py +4 −1 Original line number Diff line number Diff line Loading @@ -679,7 +679,7 @@ class CoarseTransformer(nn.Module): offsets = offsets[:, :coarse_token_ids.shape[-1]] coarse_token_ids = coarse_token_ids + offsets semantic_tokens = self.semantic_embedding(semantic_token_ids) semantic_tokens = get_embeds(self.semantic_embedding, semantic_token_ids) coarse_tokens = self.coarse_embedding(coarse_token_ids) semantic_seq_len = semantic_tokens.shape[1] Loading Loading @@ -1164,6 +1164,9 @@ class CoarseTransformerWrapper(nn.Module): with torch.no_grad(): text_embeds = self.transformer.embed_text(text, output_device = device) if self.unique_consecutive: semantic_token_ids = batch_unique_consecutive(semantic_token_ids, pad_value=self.pad_id) # initialize init_coarse_time_step = coarse_token_ids.shape[-1] Loading Loading
audiolm_pytorch/audiolm_pytorch.py +4 −1 Original line number Diff line number Diff line Loading @@ -679,7 +679,7 @@ class CoarseTransformer(nn.Module): offsets = offsets[:, :coarse_token_ids.shape[-1]] coarse_token_ids = coarse_token_ids + offsets semantic_tokens = self.semantic_embedding(semantic_token_ids) semantic_tokens = get_embeds(self.semantic_embedding, semantic_token_ids) coarse_tokens = self.coarse_embedding(coarse_token_ids) semantic_seq_len = semantic_tokens.shape[1] Loading Loading @@ -1164,6 +1164,9 @@ class CoarseTransformerWrapper(nn.Module): with torch.no_grad(): text_embeds = self.transformer.embed_text(text, output_device = device) if self.unique_consecutive: semantic_token_ids = batch_unique_consecutive(semantic_token_ids, pad_value=self.pad_id) # initialize init_coarse_time_step = coarse_token_ids.shape[-1] Loading