Unverified Commit bf6dda9e authored by Phil Wang's avatar Phil Wang Committed by GitHub
Browse files

Merge pull request #73 from zhvng/batch_unique

batch unique consecutive in CoarseTransformer
parents 09f453fc ed630be2
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -679,7 +679,7 @@ class CoarseTransformer(nn.Module):
        offsets = offsets[:, :coarse_token_ids.shape[-1]]
        coarse_token_ids = coarse_token_ids + offsets

        semantic_tokens = self.semantic_embedding(semantic_token_ids)
        semantic_tokens = get_embeds(self.semantic_embedding, semantic_token_ids)
        coarse_tokens = self.coarse_embedding(coarse_token_ids)

        semantic_seq_len = semantic_tokens.shape[1]
@@ -1164,6 +1164,9 @@ class CoarseTransformerWrapper(nn.Module):
            with torch.no_grad():
                text_embeds = self.transformer.embed_text(text, output_device = device)

        if self.unique_consecutive:
            semantic_token_ids = batch_unique_consecutive(semantic_token_ids, pad_value=self.pad_id)

        # initialize

        init_coarse_time_step = coarse_token_ids.shape[-1]