@@ -1313,7 +1311,8 @@ class CoarseTransformerWrapper(nn.Module):
coarse_loss=F.cross_entropy(
coarse_logits,
coarse_labels
coarse_labels,
ignore_index=self.pad_id
)
return (
@@ -1415,9 +1414,6 @@ class FineTransformerWrapper(nn.Module):
last_fine_logits=fine_logits[:,-1]
ifnotis_last_step:
last_fine_logits[:,-1]=float('-inf')# prevent from eos if not last quantizer step, but move this to masking logic within the transformer at some point, for both training and eval