Commit 0a248000 authored by dzy7e's avatar dzy7e
Browse files

add LayerNorm to head, learn pos_encoder

parent 2f755aed
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -41,13 +41,13 @@ class CNNHead(nn.Module):
class SigTransformer(nn.Module):
    __model_name__ = 'transformer'

    def __init__(self, in_ch=3, n_cls=2, hidden=512, nlayers=5, dropout=0.1):
    def __init__(self, in_ch=3, n_cls=2, hidden=512, nlayers=5, dropout=0.1, seq_len=90):
        super(SigTransformer, self).__init__()
        nhead = hidden // 64

        self.head = CNNHead(in_ch, hidden)
        #self.pos_encoder = PositionalEncoding(hidden, dropout)
        self.pos_embedding = nn.Parameter(torch.randn(90 + 1, 1, hidden))
        self.pos_embedding = nn.Parameter(torch.randn(seq_len + 1, 1, hidden))

        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden) * 0.02)