Commit ed630be2 authored by zhvng's avatar zhvng
Browse files

batch unique consecutive in CoarseTransformerWrapper generate

parent 6a81c3a3
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -679,7 +679,7 @@ class CoarseTransformer(nn.Module):
        offsets = offsets[:, :coarse_token_ids.shape[-1]]
        coarse_token_ids = coarse_token_ids + offsets

        semantic_tokens = self.semantic_embedding(semantic_token_ids)
        semantic_tokens = get_embeds(self.semantic_embedding, semantic_token_ids)
        coarse_tokens = self.coarse_embedding(coarse_token_ids)

        semantic_seq_len = semantic_tokens.shape[1]
@@ -1164,6 +1164,9 @@ class CoarseTransformerWrapper(nn.Module):
            with torch.no_grad():
                text_embeds = self.transformer.embed_text(text, output_device = device)

        if self.unique_consecutive:
            semantic_token_ids = batch_unique_consecutive(semantic_token_ids, pad_value=self.pad_id)

        # initialize

        init_coarse_time_step = coarse_token_ids.shape[-1]