Commit 3da99089 authored by Leon Wu's avatar Leon Wu
Browse files

Use correct eos id when masking out

parent 035c0ce9
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -1135,7 +1135,7 @@ class SemanticTransformerWrapper(nn.Module):

            last_logit_indices += 1

        sample_semantic_ids = mask_out_after_eos_id(sample_semantic_ids, self.pad_id, keep_eos = False)
        sample_semantic_ids = mask_out_after_eos_id(sample_semantic_ids, self.eos_id, keep_eos = False)

        return sample_semantic_ids