modules/ocr/mit48px.py +2 −2
@@ -328,10 +328,10 @@ def transformer_encoder_forward(
         is_causal: bool = False) -> torch.Tensor:
     x = src
     if self.norm_first:
-        x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
+        x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
         x = x + self._ff_block(self.norm2(x))
     else:
-        x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
+        x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
         x = self.norm2(x + self._ff_block(x))
     return x
modules/ocr/mit48px.py +2 −2
@@ -328,10 +328,10 @@ def transformer_encoder_forward(
         is_causal: bool = False) -> torch.Tensor:
     x = src
     if self.norm_first:
-        x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
+        x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
         x = x + self._ff_block(self.norm2(x))
     else:
-        x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
+        x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
         x = self.norm2(x + self._ff_block(x))
     return x