Commit ea360b82 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add resnet2d

parent 2d37ce0b
Loading
Loading
Loading
Loading
+0 −8
Original line number Diff line number Diff line
'''ResNet in PyTorch.

For Pre-activation ResNet, see 'preact_resnet.py'.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
+156 −0
Original line number Diff line number Diff line
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion *
                               planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=2, avgpool_size: Tuple[int, int] = (16, 16)):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(avgpool_size)
        self.linear = nn.Linear(512 * block.expansion * avgpool_size[0] * avgpool_size[1], num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        print(out.shape)
        out = self.avgpool(out)
        print(out.shape)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


class ResNet182D(ResNet):
    __model_name__ = 'resnet18_2d'
    __dims__ = 2

    def __init__(self):
        ResNet.__init__(self, BasicBlock, [2, 2, 2, 2])


class ResNet342D(ResNet):
    __model_name__ = 'resnet34_2d'
    __dims__ = 2

    def __init__(self):
        ResNet.__init__(self, BasicBlock, [3, 4, 6, 3])


class ResNet502D(ResNet):
    __model_name__ = 'resnet50_2d'
    __dims__ = 2

    def __init__(self):
        ResNet.__init__(self, Bottleneck, [3, 4, 6, 3])


class ResNet1012D(ResNet):
    __model_name__ = 'resnet101_2d'
    __dims__ = 2

    def __init__(self):
        ResNet.__init__(self, Bottleneck, [3, 4, 23, 3])


class ResNet1522D(ResNet):
    __model_name__ = 'resnet152_2d'
    __dims__ = 2

    def __init__(self):
        ResNet.__init__(self, Bottleneck, [3, 8, 36, 3])


if __name__ == '__main__':
    from thop import profile

    for resnet_class in [
        ResNet182D, ResNet342D,
        ResNet502D, ResNet1012D, ResNet1522D
    ]:
        net = resnet_class()
        x = torch.randn(1, 3, 384, 384)

        flops, params = profile(net, (x,))
        print(f'{resnet_class.__name__}:')
        print('FLOPs = ' + str(flops / 1000 ** 3) + 'G')
        print('Params = ' + str(params / 1000 ** 2) + 'M')
+7 −1
Original line number Diff line number Diff line
@@ -17,7 +17,8 @@ from .dataset import MonochromeDataset, Monochrome2DDataset, random_split_datase
from .levit1d import LeSigTransformer
from .levit2d import LeViT
from .loss import FocalLoss
from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from .resnet1d import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from .resnet2d import ResNet182D, ResNet342D, ResNet502D, ResNet1012D, ResNet1522D
from .transformer import SigTransformer
from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR

@@ -44,6 +45,11 @@ _register_model(ResNet34)
_register_model(ResNet50)
_register_model(ResNet101)
_register_model(ResNet152)
_register_model(ResNet182D)
_register_model(ResNet342D)
_register_model(ResNet502D)
_register_model(ResNet1012D)
_register_model(ResNet1522D)
_register_model(SigTransformer)
_register_model(LeSigTransformer)
_register_model(LeViT)