Commit 657f18af authored by dmMaze's avatar dmMaze
Browse files

fix #143

parent e3cddfdf
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -85,7 +85,7 @@ class CustomTransformerEncoderLayer(nn.Module):
            state['activation'] = F.relu
        super(CustomTransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    def forward(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None, is_causal = None) -> torch.Tensor:
        r"""Pass the input through the encoder layer.
        Args:
            src: the sequence to the encoder layer (required).