Commit 6874fbd6 authored by zhvng's avatar zhvng
Browse files

fix generate functions

parent 6a81c3a3
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -1174,7 +1174,7 @@ class CoarseTransformerWrapper(nn.Module):
                is_last_step = ind == (self.num_coarse_quantizers - 1)

                _, coarse_logits = self.transformer.forward_with_cond_scale(
                    coarse_token_ids = coarse_token_ids,
                    coarse_token_ids=sampled_coarse_token_ids,
                    semantic_token_ids = semantic_token_ids,
                    text_embeds = text_embeds,
                    cond_scale = cond_scale,
@@ -1388,7 +1388,7 @@ class FineTransformerWrapper(nn.Module):

                _, fine_logits = self.transformer.forward_with_cond_scale(
                    coarse_token_ids = coarse_token_ids,
                    fine_token_ids = fine_token_ids,
                    fine_token_ids = sampled_fine_token_ids,
                    text_embeds = text_embeds,
                    cond_scale = cond_scale,
                    return_only_fine_logits = True,