Commit 0406b106 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): update cnns

parent a1fd0c77
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ import torch.nn as nn
class MonochromeAlexNet(nn.Module):
    __model_name__ = 'alexnet'

    def __init__(self, input_channels: int = 3, num_classes=2, avgpool_size: int = 7):
    def __init__(self, input_channels: int = 3, num_classes=2, avgpool_size: int = 4):
        super(MonochromeAlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv1d(input_channels, 96, kernel_size=11, stride=4, padding=2),
+2 −2
Original line number Diff line number Diff line
@@ -71,7 +71,7 @@ class Bottleneck(nn.Module):


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes, avgpool_size: int = 8):
    def __init__(self, block, num_blocks, num_classes, avgpool_size: int = 23):
        super(ResNet, self).__init__()
        self.in_planes = 64

@@ -142,5 +142,5 @@ class ResNet152(ResNet):

if __name__ == '__main__':
    net = ResNet50(2)
    y = net(torch.randn(10, 3, 400))
    y = net(torch.randn(10, 3, 180))
    print(y.shape)