Loading ballontranslator/dl/ocr/mit48px_ctc.py +1 −1 Original line number Diff line number Diff line Loading @@ -85,7 +85,7 @@ class CustomTransformerEncoderLayer(nn.Module): state['activation'] = F.relu super(CustomTransformerEncoderLayer, self).__setstate__(state) def forward(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor: def forward(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None, is_causal = None) -> torch.Tensor: r"""Pass the input through the encoder layer. Args: src: the sequence to the encoder layer (required). Loading Loading
ballontranslator/dl/ocr/mit48px_ctc.py +1 −1 Original line number Diff line number Diff line Loading @@ -85,7 +85,7 @@ class CustomTransformerEncoderLayer(nn.Module): state['activation'] = F.relu super(CustomTransformerEncoderLayer, self).__setstate__(state) def forward(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor: def forward(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None, is_causal = None) -> torch.Tensor: r"""Pass the input through the encoder layer. Args: src: the sequence to the encoder layer (required). Loading