Loading zoo/monochrome/__main__.py +2 −1 Original line number Diff line number Diff line Loading @@ -78,8 +78,9 @@ def export_one(output: str, feature_bins: int, ckpt: str, model_name: str): _KNOWN_CKPTS: List[Tuple[str, str, int]] = [ # ('monochrome-alexnet-480.ckpt', 'alexnet', 180), # ('monochrome-resnet18-480.ckpt', 'resnet18', 180), ('monochrome-transformer-480.ckpt', 'transformer', 180), # ('monochrome-transformer-480.ckpt', 'transformer', 180), # ('monochrome-resnet18-safe2-450.ckpt', 'resnet18', 180), ('monochrome-levit_d0.2-500.ckpt', 'levit', 180), ] Loading zoo/monochrome/resnet2d.py +5 −2 Original line number Diff line number Diff line Loading @@ -82,7 +82,9 @@ class ResNet(nn.Module): def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) print(out.shape) out = self.layer1(out) print(out.shape) out = self.layer2(out) out = self.layer3(out) out = self.layer4(out) Loading Loading @@ -139,10 +141,11 @@ if __name__ == '__main__': for resnet_class in [ ResNet182D, ResNet342D, ResNet502D, ResNet1012D, ResNet1522D ResNet502D, # ResNet1012D, ResNet1522D ]: net = resnet_class() x = torch.randn(1, 3, 384, 384) x = torch.randn(4, 3, 160, 160) flops, params = profile(net, (x,)) print(f'{resnet_class.__name__}:') Loading Loading
zoo/monochrome/__main__.py +2 −1 Original line number Diff line number Diff line Loading @@ -78,8 +78,9 @@ def export_one(output: str, feature_bins: int, ckpt: str, model_name: str): _KNOWN_CKPTS: List[Tuple[str, str, int]] = [ # ('monochrome-alexnet-480.ckpt', 'alexnet', 180), # ('monochrome-resnet18-480.ckpt', 'resnet18', 180), ('monochrome-transformer-480.ckpt', 'transformer', 180), # ('monochrome-transformer-480.ckpt', 'transformer', 180), # ('monochrome-resnet18-safe2-450.ckpt', 'resnet18', 180), ('monochrome-levit_d0.2-500.ckpt', 'levit', 180), ] Loading
zoo/monochrome/resnet2d.py +5 −2 Original line number Diff line number Diff line Loading @@ -82,7 +82,9 @@ class ResNet(nn.Module): def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) print(out.shape) out = self.layer1(out) print(out.shape) out = self.layer2(out) out = self.layer3(out) out = self.layer4(out) Loading Loading @@ -139,10 +141,11 @@ if __name__ == '__main__': for resnet_class in [ ResNet182D, ResNet342D, ResNet502D, ResNet1012D, ResNet1522D ResNet502D, # ResNet1012D, ResNet1522D ]: net = resnet_class() x = torch.randn(1, 3, 384, 384) x = torch.randn(4, 3, 160, 160) flops, params = profile(net, (x,)) print(f'{resnet_class.__name__}:') Loading