Loading zoo/monochrome/transformer.py +12 −30 Original line number Diff line number Diff line Loading @@ -40,50 +40,32 @@ class CNNHead(nn.Module): class SigTransformer(nn.Module): __model_name__ = 'transformer' def __init__(self, in_ch=3, n_cls=2, n_query=8, hidden=384, nlayers=3, dropout=0.1): def __init__(self, in_ch=3, n_cls=2, hidden=384, nlayers=5, dropout=0.1): super(SigTransformer, self).__init__() nhead = hidden // 64 self.encoder = CNNHead(in_ch, hidden) self.head = CNNHead(in_ch, hidden) self.pos_encoder = PositionalEncoding(hidden, dropout) self.decoder = nn.Embedding(n_query, hidden) self.pos_decoder = PositionalEncoding(hidden, dropout) self.cls_token = nn.Parameter(torch.randn(1, 1, hidden) * 0.02) self.transformer = nn.Transformer( d_model=hidden, nhead=nhead, num_encoder_layers=nlayers, num_decoder_layers=nlayers, dim_feedforward=hidden, dropout=dropout, ) encoder_layer = nn.TransformerEncoderLayer(hidden, nhead, dim_feedforward=2048, dropout=dropout, activation='gelu') encoder_norm = nn.LayerNorm(hidden) self.encoder = nn.TransformerEncoder(encoder_layer, nlayers, encoder_norm) self.fc_out = nn.Linear(hidden, n_cls) self.src_mask = None self.trg_mask = None self.memory_mask = None def generate_square_subsequent_mask(self, sz): mask = torch.triu(torch.ones(sz, sz), 1) mask = mask.masked_fill(mask == 1, float('-inf')) return mask def make_len_mask(self, inp): return (inp == 0).transpose(0, 1) def forward(self, src): trg = self.decoder.weight if self.trg_mask is None or self.trg_mask.size(0) != len(trg): self.trg_mask = self.generate_square_subsequent_mask(len(trg)).to(trg.device) # src_pad_mask = self.make_len_mask(src) # trg_pad_mask = self.make_len_mask(trg) src = self.encoder(src) src = self.head(src) # [N,B,emb] cls_tokens = self.cls_token.expand(-1, src.shape[1], -1) src = torch.cat((cls_tokens, src), dim=0) src = self.pos_encoder(src) trg = trg.unsqueeze(1).repeat(1, src.shape[1], 1) # trg = self.decoder(trg) trg = self.pos_decoder(trg) output = self.transformer(src, trg, tgt_mask=self.trg_mask).transpose(0, 1) # [B,N,emb] output = self.encoder(src).transpose(0, 1) # [B,N,emb] output = self.fc_out(output[:, 0, :]) return output Loading @@ -91,6 +73,6 @@ class SigTransformer(nn.Module): if __name__ == '__main__': transformer = SigTransformer() x = torch.randn(8, 3, 380) x = torch.randn(8, 3, 400) y = transformer(x) print(y.shape) Loading
zoo/monochrome/transformer.py +12 −30 Original line number Diff line number Diff line Loading @@ -40,50 +40,32 @@ class CNNHead(nn.Module): class SigTransformer(nn.Module): __model_name__ = 'transformer' def __init__(self, in_ch=3, n_cls=2, n_query=8, hidden=384, nlayers=3, dropout=0.1): def __init__(self, in_ch=3, n_cls=2, hidden=384, nlayers=5, dropout=0.1): super(SigTransformer, self).__init__() nhead = hidden // 64 self.encoder = CNNHead(in_ch, hidden) self.head = CNNHead(in_ch, hidden) self.pos_encoder = PositionalEncoding(hidden, dropout) self.decoder = nn.Embedding(n_query, hidden) self.pos_decoder = PositionalEncoding(hidden, dropout) self.cls_token = nn.Parameter(torch.randn(1, 1, hidden) * 0.02) self.transformer = nn.Transformer( d_model=hidden, nhead=nhead, num_encoder_layers=nlayers, num_decoder_layers=nlayers, dim_feedforward=hidden, dropout=dropout, ) encoder_layer = nn.TransformerEncoderLayer(hidden, nhead, dim_feedforward=2048, dropout=dropout, activation='gelu') encoder_norm = nn.LayerNorm(hidden) self.encoder = nn.TransformerEncoder(encoder_layer, nlayers, encoder_norm) self.fc_out = nn.Linear(hidden, n_cls) self.src_mask = None self.trg_mask = None self.memory_mask = None def generate_square_subsequent_mask(self, sz): mask = torch.triu(torch.ones(sz, sz), 1) mask = mask.masked_fill(mask == 1, float('-inf')) return mask def make_len_mask(self, inp): return (inp == 0).transpose(0, 1) def forward(self, src): trg = self.decoder.weight if self.trg_mask is None or self.trg_mask.size(0) != len(trg): self.trg_mask = self.generate_square_subsequent_mask(len(trg)).to(trg.device) # src_pad_mask = self.make_len_mask(src) # trg_pad_mask = self.make_len_mask(trg) src = self.encoder(src) src = self.head(src) # [N,B,emb] cls_tokens = self.cls_token.expand(-1, src.shape[1], -1) src = torch.cat((cls_tokens, src), dim=0) src = self.pos_encoder(src) trg = trg.unsqueeze(1).repeat(1, src.shape[1], 1) # trg = self.decoder(trg) trg = self.pos_decoder(trg) output = self.transformer(src, trg, tgt_mask=self.trg_mask).transpose(0, 1) # [B,N,emb] output = self.encoder(src).transpose(0, 1) # [B,N,emb] output = self.fc_out(output[:, 0, :]) return output Loading @@ -91,6 +73,6 @@ class SigTransformer(nn.Module): if __name__ == '__main__': transformer = SigTransformer() x = torch.randn(8, 3, 380) x = torch.randn(8, 3, 400) y = transformer(x) print(y.shape)