Commit 2a39b318 authored by narugo1992
Browse files

dev(narugo): use new transformer

parent 4a9928d2
Loading
Loading
Loading
Loading
+12 −30
Original line number | Diff line number | Diff line
@@ -40,50 +40,32 @@ class CNNHead(nn.Module):
class SigTransformer(nn.Module):
    # Classifier that runs CNNHead features through a transformer and maps
    # one position of the transformer output to `n_cls` values via `fc_out`.
    #
    # NOTE(review): this listing looks like a rendered commit diff whose
    # removed and added lines are interleaved without +/- markers, so several
    # statements below appear in two forms. As written it is not valid Python
    # (two consecutive `def __init__` headers). Confirm against the repository
    # which lines are live before relying on any of it.
    __model_name__ = 'transformer'

    # Two constructor signatures are listed back to back. This one accepts
    # `n_query` and defaults to nlayers=3 ...
    def __init__(self, in_ch=3, n_cls=2, n_query=8, hidden=384, nlayers=3, dropout=0.1):
    # ... while this one has no `n_query` and defaults to nlayers=5.
    # Presumably the old and new form of the same line -- TODO confirm.
    def __init__(self, in_ch=3, n_cls=2, hidden=384, nlayers=5, dropout=0.1):
        super(SigTransformer, self).__init__()
        # One attention head per 64 hidden units (6 heads for hidden=384).
        nhead = hidden // 64

        # The same CNNHead(in_ch, hidden) construction is bound to both
        # `self.encoder` and `self.head`; `self.encoder` is rebound to an
        # nn.TransformerEncoder further down.
        self.encoder = CNNHead(in_ch, hidden)
        self.head = CNNHead(in_ch, hidden)
        self.pos_encoder = PositionalEncoding(hidden, dropout)

        # `n_query` learned embeddings; forward() reads `self.decoder.weight`
        # directly as the decoder-side input.
        self.decoder = nn.Embedding(n_query, hidden)
        self.pos_decoder = PositionalEncoding(hidden, dropout)
        # Learnable tensor of shape (1, 1, hidden), drawn from a standard
        # normal and scaled by 0.02; forward() prepends it to the sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden) * 0.02)

        # Encoder-decoder transformer with `nlayers` layers on each side and a
        # feed-forward width equal to `hidden`.
        self.transformer = nn.Transformer(
            d_model=hidden, nhead=nhead, num_encoder_layers=nlayers,
            num_decoder_layers=nlayers, dim_feedforward=hidden, dropout=dropout,
        )
        # Encoder-only stack: `nlayers` GELU layers with feed-forward width
        # 2048, followed by a final LayerNorm.
        encoder_layer = nn.TransformerEncoderLayer(hidden, nhead, dim_feedforward=2048, dropout=dropout,
                                                   activation='gelu')
        encoder_norm = nn.LayerNorm(hidden)
        self.encoder = nn.TransformerEncoder(encoder_layer, nlayers, encoder_norm)
        # Linear projection from `hidden` features to `n_cls` outputs.
        self.fc_out = nn.Linear(hidden, n_cls)

        # Mask slots. Only `trg_mask` is read in the visible code (lazily
        # built in forward()); `src_mask` and `memory_mask` are never read here.
        self.src_mask = None
        self.trg_mask = None
        self.memory_mask = None

    def generate_square_subsequent_mask(self, sz):
        # Build a (sz, sz) float mask: -inf strictly above the main diagonal,
        # 0 on and below it.
        mask = torch.triu(torch.ones(sz, sz), 1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def make_len_mask(self, inp):
        # Boolean mask that is True where `inp` equals 0, with dims 0 and 1
        # swapped. Only referenced from commented-out lines in forward().
        return (inp == 0).transpose(0, 1)

    def forward(self, src):
        # Decoder-side input is the raw embedding table (n_query rows).
        trg = self.decoder.weight

        # Build the decoder mask on first use, or rebuild it when its size no
        # longer matches the number of rows in `trg`.
        if self.trg_mask is None or self.trg_mask.size(0) != len(trg):
            self.trg_mask = self.generate_square_subsequent_mask(len(trg)).to(trg.device)

        # src_pad_mask = self.make_len_mask(src)
        # trg_pad_mask = self.make_len_mask(trg)

        # Two feature-extraction lines are listed, one via `self.encoder` and
        # one via `self.head`. Presumably old vs new -- TODO confirm.
        src = self.encoder(src)
        src = self.head(src)  # [N,B,emb]
        # Expand the class token along dim 1 to src.shape[1] and prepend it
        # along dim 0, so it occupies index 0 of that dimension.
        cls_tokens = self.cls_token.expand(-1, src.shape[1], -1)
        src = torch.cat((cls_tokens, src), dim=0)
        src = self.pos_encoder(src)

        # Repeat the embedding table along a new dim 1 to match src.shape[1].
        trg = trg.unsqueeze(1).repeat(1, src.shape[1], 1)
        # trg = self.decoder(trg)
        trg = self.pos_decoder(trg)
        # Two ways of producing `output` are listed: the encoder-decoder
        # `self.transformer` with the decoder mask, and `self.encoder` alone.
        # If both ran in this order, the second assignment would overwrite
        # the first.
        output = self.transformer(src, trg, tgt_mask=self.trg_mask).transpose(0, 1)  # [B,N,emb]
        output = self.encoder(src).transpose(0, 1)  # [B,N,emb]
        # Classify from index 0 of dim 1 after the transpose; on the
        # `self.encoder(src)` path this is where the class token was prepended.
        output = self.fc_out(output[:, 0, :])

        return output
@@ -91,6 +73,6 @@ class SigTransformer(nn.Module):

# Smoke test: build the model with default arguments, run one random batch
# through it, and print the shape of the result.
if __name__ == '__main__':
    transformer = SigTransformer()
    # Two input lines are listed, differing only in the last dimension
    # (380 vs 400); if both ran, the second would overwrite the first.
    # Presumably the old and new form of the same line in a diff -- TODO confirm.
    # NOTE(review): layout is presumably (batch, in_ch, length) given in_ch=3
    # -- confirm against CNNHead.
    x = torch.randn(8, 3, 380)
    x = torch.randn(8, 3, 400)
    y = transformer(x)
    print(y.shape)