Commit 4a9928d2 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add resnet and transformer

parent f4dab666
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -3,6 +3,8 @@ import torch.nn as nn


class MonochromeAlexNet(nn.Module):
    __model_name__ = 'alexnet'

    def __init__(self, input_channels: int = 3, num_classes=2):
        super(MonochromeAlexNet, self).__init__()
        self.features = nn.Sequential(
+146 −0
Original line number Diff line number Diff line
'''ResNet in PyTorch.

For Pre-activation ResNet, see 'preact_resnet.py'.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv1d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_planes, self.expansion * planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.conv3 = nn.Conv1d(planes, self.expansion *
                               planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm1d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_planes, self.expansion * planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes, avgpool_size: int = 6):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv1d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool1d(avgpool_size)
        self.linear = nn.Linear(512 * block.expansion * avgpool_size, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


class ResNet18(ResNet):
    __model_name__ = 'resnet18'

    def __init__(self, num_classes: int = 2):
        ResNet.__init__(self, BasicBlock, [2, 2, 2, 2], num_classes)


class ResNet34(ResNet):
    __model_name__ = 'resnet34'

    def __init__(self, num_classes: int = 2):
        ResNet.__init__(self, BasicBlock, [3, 4, 6, 3], num_classes)


class ResNet50(ResNet):
    __model_name__ = 'resnet50'

    def __init__(self, num_classes: int = 2):
        ResNet.__init__(self, Bottleneck, [3, 4, 6, 3], num_classes)


class ResNet101(ResNet):
    __model_name__ = 'resnet101'

    def __init__(self, num_classes: int = 2):
        ResNet.__init__(self, Bottleneck, [3, 4, 23, 3], num_classes)


class ResNet152(ResNet):
    __model_name__ = 'resnet152'

    def __init__(self, num_classes: int = 2):
        ResNet.__init__(self, Bottleneck, [3, 8, 36, 3], num_classes)


if __name__ == '__main__':
    net = ResNet50(2)
    y = net(torch.randn(10, 3, 400))
    print(y.shape)
+31 −10
Original line number Diff line number Diff line
import os.path
import re
from typing import Optional
from functools import partial
from typing import Optional, Type

import torch
from ditk import logging
@@ -11,21 +12,39 @@ from tqdm.auto import tqdm

from .alexnet import MonochromeAlexNet
from .dataset import MonochromeDataset
from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from .transformer import SigTransformer
from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR

_TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'monochrome')
_LOG_DIR = os.path.join(_TRAIN_DIR, 'logs')
_CKPT_DIR = os.path.join(_TRAIN_DIR, 'ckpts')

_CKPT_PATTERN = re.compile(r'^monochrome-(?P<epoch>\d+)\.ckpt$')
_CKPT_PATTERN = re.compile(r'^monochrome-(?P<name>[a-zA-Z\d_\-]+)-(?P<epoch>\d+)\.ckpt$')

_KNOWN_MODELS = {}

def _find_latest_ckpt() -> Optional[str]:

def _register_model(cls: Type[nn.Module], *args, name=None, **kwargs):
    name = name or cls.__model_name__
    _KNOWN_MODELS[name] = partial(cls, *args, **kwargs)


_register_model(MonochromeAlexNet)
_register_model(ResNet18)
_register_model(ResNet34)
_register_model(ResNet50)
_register_model(ResNet101)
_register_model(ResNet152)
_register_model(SigTransformer)


def _find_latest_ckpt(name: str) -> Optional[str]:
    if os.path.exists(_CKPT_DIR):
        ckpts = []
        for filename in os.listdir(_CKPT_DIR):
            matching = _CKPT_PATTERN.fullmatch(os.path.basename(filename))
            if matching:
            if matching and matching.group('name') == name:
                ckpts.append((int(matching.group('epoch')), os.path.join(_CKPT_DIR, filename)))

        ckpts = sorted(ckpts)
@@ -50,10 +69,11 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]:

def train(dataset_dir: str, from_ckpt: Optional[str] = None, train_ratio: float = 0.8,
          batch_size: int = 4, feature_bins: int = 400, max_epochs: int = 500, learning_rate: float = 0.001,
          save_per_epoch: int = 10):
    os.makedirs(_LOG_DIR, exist_ok=True)
          save_per_epoch: int = 10, model_name: str = 'alexnet'):
    _log_dir = os.path.join(_LOG_DIR, model_name)
    os.makedirs(_log_dir, exist_ok=True)
    os.makedirs(_CKPT_DIR, exist_ok=True)
    writer = SummaryWriter(_LOG_DIR)
    writer = SummaryWriter(_log_dir)
    writer.add_custom_scalars({
        "general": {
            "accuracy": ["Multiline", ["train/accuracy", "test/accuracy"]],
@@ -72,9 +92,10 @@ def train(dataset_dir: str, from_ckpt: Optional[str] = None, train_ratio: float
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

    # Load previous epoch
    model = MonochromeAlexNet().float()
    model = _KNOWN_MODELS[model_name]().float()
    # model = MonochromeAlexNet().float()
    if from_ckpt is None:
        from_ckpt = _find_latest_ckpt()
        from_ckpt = _find_latest_ckpt(model_name)
    previous_epoch = _ckpt_epoch(from_ckpt) or 0
    if from_ckpt:
        logging.info(f'Load checkpoint from {from_ckpt!r}.')
@@ -138,6 +159,6 @@ def train(dataset_dir: str, from_ckpt: Optional[str] = None, train_ratio: float
            writer.add_scalar('test/accuracy', test_accuracy, epoch)

        if epoch % save_per_epoch == 0:
            current_ckpt_file = os.path.join(_CKPT_DIR, f'monochrome-{epoch}.ckpt')
            current_ckpt_file = os.path.join(_CKPT_DIR, f'monochrome-{model_name}-{epoch}.ckpt')
            torch.save(model.state_dict(), current_ckpt_file)
            logging.info(f'Saved to {current_ckpt_file!r}.')
+96 −0
Original line number Diff line number Diff line
import math

import torch
from torch import nn


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=100):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class CNNHead(nn.Module):
    def __init__(self, in_chans=1, embed_dim=768):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Conv1d(in_chans, embed_dim // 2, kernel_size=7, stride=2),
            nn.BatchNorm1d(embed_dim // 2),
            nn.SiLU(),
            nn.Conv1d(embed_dim // 2, embed_dim, kernel_size=5, stride=2),
        )

    def forward(self, x):  # x:[B,ch,N_seq]
        x = self.proj(x).permute(2, 0, 1)
        return x


class SigTransformer(nn.Module):
    __model_name__ = 'transformer'

    def __init__(self, in_ch=3, n_cls=2, n_query=8, hidden=384, nlayers=3, dropout=0.1):
        super(SigTransformer, self).__init__()
        nhead = hidden // 64

        self.encoder = CNNHead(in_ch, hidden)
        self.pos_encoder = PositionalEncoding(hidden, dropout)

        self.decoder = nn.Embedding(n_query, hidden)
        self.pos_decoder = PositionalEncoding(hidden, dropout)

        self.transformer = nn.Transformer(
            d_model=hidden, nhead=nhead, num_encoder_layers=nlayers,
            num_decoder_layers=nlayers, dim_feedforward=hidden, dropout=dropout,
        )
        self.fc_out = nn.Linear(hidden, n_cls)

        self.src_mask = None
        self.trg_mask = None
        self.memory_mask = None

    def generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz), 1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def make_len_mask(self, inp):
        return (inp == 0).transpose(0, 1)

    def forward(self, src):
        trg = self.decoder.weight

        if self.trg_mask is None or self.trg_mask.size(0) != len(trg):
            self.trg_mask = self.generate_square_subsequent_mask(len(trg)).to(trg.device)

        # src_pad_mask = self.make_len_mask(src)
        # trg_pad_mask = self.make_len_mask(trg)

        src = self.encoder(src)
        src = self.pos_encoder(src)

        trg = trg.unsqueeze(1).repeat(1, src.shape[1], 1)
        # trg = self.decoder(trg)
        trg = self.pos_decoder(trg)
        output = self.transformer(src, trg, tgt_mask=self.trg_mask).transpose(0, 1)  # [B,N,emb]
        output = self.fc_out(output[:, 0, :])

        return output


if __name__ == '__main__':
    transformer = SigTransformer()
    x = torch.randn(8, 3, 380)
    y = transformer(x)
    print(y.shape)