Loading zoo/monochrome/alexnet.py +1 −1 Original line number Diff line number Diff line Loading @@ -5,7 +5,7 @@ import torch.nn as nn class MonochromeAlexNet(nn.Module): __model_name__ = 'alexnet' def __init__(self, input_channels: int = 3, num_classes=2, avgpool_size: int = 7): def __init__(self, input_channels: int = 3, num_classes=2, avgpool_size: int = 4): super(MonochromeAlexNet, self).__init__() self.features = nn.Sequential( nn.Conv1d(input_channels, 96, kernel_size=11, stride=4, padding=2), Loading zoo/monochrome/resnet.py +2 −2 Original line number Diff line number Diff line Loading @@ -71,7 +71,7 @@ class Bottleneck(nn.Module): class ResNet(nn.Module): def __init__(self, block, num_blocks, num_classes, avgpool_size: int = 8): def __init__(self, block, num_blocks, num_classes, avgpool_size: int = 23): super(ResNet, self).__init__() self.in_planes = 64 Loading Loading @@ -142,5 +142,5 @@ class ResNet152(ResNet): if __name__ == '__main__': net = ResNet50(2) y = net(torch.randn(10, 3, 400)) y = net(torch.randn(10, 3, 180)) print(y.shape) Loading
zoo/monochrome/alexnet.py +1 −1 Original line number Diff line number Diff line Loading @@ -5,7 +5,7 @@ import torch.nn as nn class MonochromeAlexNet(nn.Module): __model_name__ = 'alexnet' def __init__(self, input_channels: int = 3, num_classes=2, avgpool_size: int = 7): def __init__(self, input_channels: int = 3, num_classes=2, avgpool_size: int = 4): super(MonochromeAlexNet, self).__init__() self.features = nn.Sequential( nn.Conv1d(input_channels, 96, kernel_size=11, stride=4, padding=2), Loading
zoo/monochrome/resnet.py +2 −2 Original line number Diff line number Diff line Loading @@ -71,7 +71,7 @@ class Bottleneck(nn.Module): class ResNet(nn.Module): def __init__(self, block, num_blocks, num_classes, avgpool_size: int = 8): def __init__(self, block, num_blocks, num_classes, avgpool_size: int = 23): super(ResNet, self).__init__() self.in_planes = 64 Loading Loading @@ -142,5 +142,5 @@ class ResNet152(ResNet): if __name__ == '__main__': net = ResNet50(2) y = net(torch.randn(10, 3, 400)) y = net(torch.randn(10, 3, 180)) print(y.shape)