Commit ab850fbf authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): save those buggy code

parent c800b41b
Loading
Loading
Loading
Loading
+25 −16
Original line number Diff line number Diff line
import os.path

import numpy as np
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
@@ -7,13 +8,14 @@ from huggingface_hub import hf_hub_download
from imgutils.data import load_image
from imgutils.preprocess.pillow import PillowConvertRGB, PillowResize, PillowToTensor, PillowCompose
from test.testings import get_testfile
from .model import DeepDanbooruModel
from .model import DeepDanbooruModel, ModelWithoutTags

TORCH_DTYPE = torch.float32


def load_model():
    return DeepDanbooruModel.from_single_file(
def load_model(no_tags: bool = False):
    cls = DeepDanbooruModel if not no_tags else ModelWithoutTags
    return cls.from_single_file(
        hf_hub_download(
            repo_id='v2ray/deepgelbooru',
            repo_type='model',
@@ -24,18 +26,24 @@ def load_model():
    )


# _PIC_FILE = get_testfile('nude_girl.png')
_PIC_FILE = get_testfile('6125785.jpg')
_PIC_FILE = get_testfile('nude_girl.png')


def get_dummy_input():
    pic = load_image(_PIC_FILE, mode='RGB')
    compose = PillowCompose([
# _PIC_FILE = get_testfile('6125785.jpg')


def get_preprocessor():
    return PillowCompose([
        PillowConvertRGB(),
        PillowResize((512, 512)),
        PillowToTensor(),
    ])
    return compose(pic).transpose((1, 2, 0))[None, ...]


def get_dummy_input():
    pic = load_image(_PIC_FILE, mode='RGB')
    compose = get_preprocessor()
    return compose(pic).transpose((1, 2, 0))[None, ...].astype(np.float32)


def load_tags_list():
@@ -66,7 +74,7 @@ def load_tags_list():


if __name__ == '__main__':
    model = load_model()
    model = load_model(no_tags=True)
    print(model)
    # quit()

@@ -78,11 +86,12 @@ if __name__ == '__main__':

    df_tags = load_tags_list()
    print(df_tags)
    assert len(df_tags) == len(model.tags)
    for item in df_tags.to_dict('records'):
        assert model.tags[item['tag_id']] == item['name'], \
            f'Tag #{item["tag_id"]!r} not match, {item["name"]!r} expected but {model.tags[item["tag_id"]]!r} found.'
    d_tags = {item['tag_id']: item for item in df_tags.to_dict('records')}
    # assert len(df_tags) == len(model.tags)
    # for item in df_tags.to_dict('records'):
    #     assert model.tags[item['tag_id']] == item['name'], \
    #         f'Tag #{item["tag_id"]!r} not match, {item["name"]!r} expected but {model.tags[item["tag_id"]]!r} found.'

    for i, prob in sorted(((i, float(prob)) for i, prob in enumerate(y)), key=lambda x: x[1]):
        if prob >= 0.3:
            print(model.tags[i], "-", prob)
        if prob >= 0.2:
            print(d_tags[i]['name'], "-", prob)
+76 −0
Original line number Diff line number Diff line
import json
import os
import tempfile

import onnx
import torch
from hbutils.system import TemporaryDirectory
from hfutils.operate import upload_directory_as_directory

from imgutils.preprocess import parse_pillow_transforms
from zoo.utils import onnx_optimize
from .demo import load_model, get_dummy_input, load_tags_list, get_preprocessor, TORCH_DTYPE


def export_model_to_onnx(model, onnx_filename, opset_version: int = 17, verbose: bool = True,
                         no_optimize: bool = False):
    dummy_input = torch.from_numpy(get_dummy_input()).to('cpu').type(TORCH_DTYPE)
    with torch.no_grad(), tempfile.TemporaryDirectory() as td:
        onnx_model_file = os.path.join(td, 'model.onnx')
        torch.onnx.export(
            model,
            dummy_input,
            onnx_model_file,
            verbose=verbose,
            input_names=["input"],
            output_names=["prediction"],

            opset_version=opset_version,
            dynamic_axes={
                "input": {0: "batch"},
                "prediction": {0: "batch"},
            }
        )

        model = onnx.load(onnx_model_file)
        if not no_optimize:
            model = onnx_optimize(model)

        output_model_dir, _ = os.path.split(onnx_filename)
        if output_model_dir:
            os.makedirs(output_model_dir, exist_ok=True)
        onnx.save(model, onnx_filename)


def export(repository: str):
    with TemporaryDirectory() as upload_dir:
        model = load_model(no_tags=True)
        # print(model.tags)
        # del model.tags
        # del model.num_tags
        export_model_to_onnx(
            model=model,
            onnx_filename=os.path.join(upload_dir, 'model.onnx')
        )

        df_tags = load_tags_list()
        df_tags.to_csv(os.path.join(upload_dir, 'tags.csv'), index=False)

        with open(os.path.join(upload_dir, 'preprocessor.json'), 'w') as f:
            json.dump({
                'stages': parse_pillow_transforms(get_preprocessor()),
            }, f, indent=4, sort_keys=True, ensure_ascii=False)

        upload_directory_as_directory(
            repo_id=repository,
            repo_type='model',
            local_directory=upload_dir,
            path_in_repo='.',
            message='Syncing deepgelbooru ONNX model',
        )


if __name__ == '__main__':
    export(
        repository='deepghs/deepgelbooru_onnx',
    )
+10 −1
Original line number Diff line number Diff line
@@ -702,3 +702,12 @@ class DeepDanbooruModel(nn.Module):
            model.to(dtype=torch_dtype)
        model.to(device=device_map)
        return model


class ModelWithoutTags(DeepDanbooruModel):
    def load_state_dict(self, state_dict, **kwargs):
        state_dict.pop("tags", None)
        nn.Module.load_state_dict(self, state_dict, **kwargs)

    def state_dict(self, **kwargs):
        return nn.Module.state_dict(self, **kwargs)