Commit 248f260f authored by dmMaze
Browse files

fix #371

parent 704aca60
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -328,10 +328,10 @@ def transformer_encoder_forward(
     is_causal: bool = False) -> torch.Tensor:
     x = src
     if self.norm_first:
-        x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
+        x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
         x = x + self._ff_block(self.norm2(x))
     else:
-        x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
+        x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
         x = self.norm2(x + self._ff_block(x))

     return x