Commit b14156f8 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): update resnet and alexnet

parent a08e43cf
Loading
Loading
Loading
Loading
+61 −0
Original line number Diff line number Diff line
import os.path
from functools import partial
from typing import List, Tuple

import click
import torch
from huggingface_hub import hf_hub_download
from tqdm.auto import tqdm

from .onnx import export_model_to_onnx
from .train_ import _KNOWN_MODELS
from ..utils import GLOBAL_CONTEXT_SETTINGS
from ..utils import print_version as _origin_print_version

print_version = partial(_origin_print_version, 'zoo.lpips')


@click.group(context_settings={**GLOBAL_CONTEXT_SETTINGS})
@click.option('-v', '--version', is_flag=True,
              callback=print_version, expose_value=False, is_eager=True,
              help="Utils with pixiv resources.")
def cli():
    pass  # pragma: no cover


@cli.command('export_one', help='Export one model as onnx.',
             context_settings={**GLOBAL_CONTEXT_SETTINGS})
@click.option('--output', '-o', 'output', type=click.Path(dir_okay=False), required=True,
              help='Output path of feature model.', show_default=True)
@click.option('--feature_bins', '-b', 'feature_bins', type=int, default=256,
              help='Feature bins of input.', show_default=True)
@click.option('--ckpt', '-c', 'ckpt', type=click.Path(exists=True, dir_okay=False), required=True,
              help='Checkpoint file to export.', show_default=True)
@click.option('--model', '-m', 'model_name', type=click.Choice(list(_KNOWN_MODELS.keys())), required=True,
              help='Name of model to export.', show_default=True)
def export_one(output: str, feature_bins: int, ckpt: str, model_name: str):
    model = _KNOWN_MODELS[model_name]().float()
    model.load_state_dict(torch.load(ckpt, map_location='cpu'))
    export_model_to_onnx(model, output, feature_bins=feature_bins)


_KNOWN_CKPTS: List[Tuple[str, str, int]] = [
    ('monochrome-alexnet_x-400.ckpt', 'alexnet', 256),
]


@cli.command('export', help='Export all models as onnx.',
             context_settings={**GLOBAL_CONTEXT_SETTINGS})
@click.option('--output_dir', '-O', 'output_dir', type=click.Path(file_okay=False), required=True,
              help='Output directory of all models.', show_default=True)
def export(output_dir: str):
    for ckpt, model_name, feature_bins in tqdm(_KNOWN_CKPTS):
        model = _KNOWN_MODELS[model_name]().float()
        ckpt_file = hf_hub_download('deepghs/imgutils-models', f'monochrome/{ckpt}')
        model.load_state_dict(torch.load(ckpt_file, map_location='cpu'))
        output_file = os.path.join(output_dir, os.path.basename(ckpt))
        export_model_to_onnx(model, output_file, feature_bins=feature_bins)


if __name__ == '__main__':
    cli()
+3 −3
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ import torch.nn as nn
class MonochromeAlexNet(nn.Module):
    __model_name__ = 'alexnet'

    def __init__(self, input_channels: int = 3, num_classes=2):
    def __init__(self, input_channels: int = 3, num_classes=2, avgpool_size: int = 7):
        super(MonochromeAlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv1d(input_channels, 96, kernel_size=11, stride=4, padding=2),
@@ -22,10 +22,10 @@ class MonochromeAlexNet(nn.Module):
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool1d(6)
        self.avgpool = nn.AdaptiveAvgPool1d(avgpool_size)
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6, 4096),
            nn.Linear(256 * avgpool_size, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),

zoo/monochrome/onnx.py

0 → 100644
+58 −0
Original line number Diff line number Diff line
import os
import tempfile

import onnx
import torch
from PIL import Image
from torch import nn

from .encode import image_encode
from ..utils import get_testfile, onnx_optimize


class ModelWithSoftMax(nn.Module):
    def __init__(self, model):
        nn.Module.__init__(self)
        self.model = model

    def forward(self, x):
        x = self.model(x)
        x = torch.softmax(x, dim=1)
        return x


def export_model_to_onnx(model, onnx_filename, opset_version: int = 14, verbose: bool = True,
                         no_optimize: bool = False, feature_bins: int = 256):
    image = Image.open(get_testfile('6125785.jpg')).convert('RGB')
    example_input = image_encode(image, bins=feature_bins, normalize=True).float().unsqueeze(0)
    model = ModelWithSoftMax(model).float()

    if torch.cuda.is_available():
        example_input = example_input.cuda()
        model = model.cuda()

    with torch.no_grad(), tempfile.TemporaryDirectory() as td:
        onnx_model_file = os.path.join(td, 'model.onnx')
        torch.onnx.export(
            model,
            example_input,
            onnx_model_file,
            verbose=verbose,
            input_names=["input"],
            output_names=["output"],

            opset_version=opset_version,
            dynamic_axes={
                "input": {0: "batch"},
                "output": {0: "batch"},
            }
        )

        model = onnx.load(onnx_model_file)
        if not no_optimize:
            model = onnx_optimize(model)

        output_model_dir, _ = os.path.split(onnx_filename)
        if output_model_dir:
            os.makedirs(output_model_dir, exist_ok=True)
        onnx.save(model, onnx_filename)
+1 −1
Original line number Diff line number Diff line
@@ -71,7 +71,7 @@ class Bottleneck(nn.Module):


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes, avgpool_size: int = 6):
    def __init__(self, block, num_blocks, num_classes, avgpool_size: int = 8):
        super(ResNet, self).__init__()
        self.in_planes = 64