Commit 4236b412 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add onnx for levit_d0.2

parent 42dda6fa
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -78,8 +78,9 @@ def export_one(output: str, feature_bins: int, ckpt: str, model_name: str):
_KNOWN_CKPTS: List[Tuple[str, str, int]] = [
    # ('monochrome-alexnet-480.ckpt', 'alexnet', 180),
    # ('monochrome-resnet18-480.ckpt', 'resnet18', 180),
    ('monochrome-transformer-480.ckpt', 'transformer', 180),
    # ('monochrome-transformer-480.ckpt', 'transformer', 180),
    # ('monochrome-resnet18-safe2-450.ckpt', 'resnet18', 180),
    ('monochrome-levit_d0.2-500.ckpt', 'levit', 180),
]


+5 −2
Original line number Diff line number Diff line
@@ -82,7 +82,9 @@ class ResNet(nn.Module):

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        print(out.shape)
        out = self.layer1(out)
        print(out.shape)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
@@ -139,10 +141,11 @@ if __name__ == '__main__':

    for resnet_class in [
        ResNet182D, ResNet342D,
        ResNet502D, ResNet1012D, ResNet1522D
        ResNet502D,
        # ResNet1012D, ResNet1522D
    ]:
        net = resnet_class()
        x = torch.randn(1, 3, 384, 384)
        x = torch.randn(4, 3, 160, 160)

        flops, params = profile(net, (x,))
        print(f'{resnet_class.__name__}:')