Commit 2f755aed authored by dzy7e's avatar dzy7e
Browse files

Add LayerNorm to the classification head (mlp_head); replace the fixed PositionalEncoding with a learned positional embedding

parent 904cdff6
Loading
Loading
Loading
Loading
+17 −16
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ from torch import nn


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=100):
    def __init__(self, d_model, dropout=0.1, max_len=200):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

@@ -26,10 +26,11 @@ class CNNHead(nn.Module):
    def __init__(self, in_chans=1, embed_dim=768):
        """Build the convolutional projection applied to the raw input signal.

        Args:
            in_chans: Number of channels of the input signal.
            embed_dim: Number of output channels produced by the projection.
        """
        super().__init__()
        # Single conv with kernel 2 / stride 2. The earlier
        # Conv(7, s2) -> BatchNorm -> SiLU -> Conv(5, s2) stem was removed;
        # it must not be stacked in front of this layer, because this conv
        # expects `in_chans` input channels, not `embed_dim`.
        self.proj = nn.Sequential(
            nn.Conv1d(in_chans, embed_dim, kernel_size=2, stride=2),
        )

    def forward(self, x):  # x:[B,ch,N_seq]
@@ -40,33 +41,33 @@ class CNNHead(nn.Module):
class SigTransformer(nn.Module):
    """Transformer-encoder classifier on top of a CNN-embedded signal.

    The input is embedded by ``CNNHead``, a learnable class token is
    prepended along the sequence axis (dim 0), a learned positional embedding
    is added, and the encoder output at the class-token position is mapped to
    ``n_cls`` values by a LayerNorm + Linear head.

    Args:
        in_ch: Number of input channels passed to ``CNNHead``.
        n_cls: Size of the output produced by the final linear layer.
        hidden: Embedding width used by the CNN head, the encoder and the
            output head. The number of attention heads is ``hidden // 64``.
        nlayers: Number of stacked encoder layers.
        dropout: Dropout probability passed to each encoder layer.
        max_len: Maximum number of embedded positions (excluding the class
            token) covered by the learned positional embedding.
    """
    __model_name__ = 'transformer'

    def __init__(self, in_ch=3, n_cls=2, hidden=512, nlayers=5, dropout=0.1, max_len=90):
        super(SigTransformer, self).__init__()
        nhead = hidden // 64

        self.head = CNNHead(in_ch, hidden)
        # Learned positional embedding: one row per embedded position plus
        # one for the class token. Replaces the former PositionalEncoding.
        self.pos_embedding = nn.Parameter(torch.randn(max_len + 1, 1, hidden))

        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden) * 0.02)

        encoder_layer = nn.TransformerEncoderLayer(hidden, nhead, dim_feedforward=2048, dropout=dropout)
        encoder_norm = nn.LayerNorm(hidden)
        self.encoder = nn.TransformerEncoder(encoder_layer, nlayers, encoder_norm)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Linear(hidden, n_cls),
        )

    def forward(self, src):
        """Return the head output for the class token of each batch element.

        Args:
            src: Input accepted by ``CNNHead``.

        Returns:
            Tensor produced by ``mlp_head`` from the class-token encoding.

        Raises:
            ValueError: If the embedded sequence (including the class token)
                is longer than the learned positional embedding.
        """
        src = self.head(src)  # [N,B,emb]
        cls_tokens = self.cls_token.expand(-1, src.shape[1], -1)
        src = torch.cat((cls_tokens, src), dim=0)

        seq_len = src.shape[0]
        if seq_len > self.pos_embedding.shape[0]:
            raise ValueError(
                f"sequence length {seq_len} (incl. class token) exceeds the "
                f"positional embedding length {self.pos_embedding.shape[0]}"
            )
        # Slice to the actual length so shorter sequences work too; the
        # out-of-place add avoids the shape constraint of in-place broadcasting.
        src = src + self.pos_embedding[:seq_len]

        output = self.encoder(src).transpose(0, 1)  # [B,N,emb]
        # Only the class-token position (index 0) feeds the output head.
        return self.mlp_head(output[:, 0, :])