Commit a0f7177a authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): addd x

parent 6f5b05d1
Loading
Loading
Loading
Loading
+177 −5
Original line number Diff line number Diff line
import copy
import datetime
import json
import os.path

import numpy as np
import onnx
import onnxruntime
import pandas as pd
import torch
from PIL import Image
from ditk import logging
from hbutils.string import plural_word
from hbutils.system import TemporaryDirectory
from hfutils.operate import get_hf_fs, get_hf_client
from hfutils.cache import delete_detached_cache
from hfutils.operate import get_hf_fs, get_hf_client, upload_directory_as_directory
from hfutils.repository import hf_hub_repo_file_url
from natsort import natsorted
from procslib import get_model
from procslib.models.pixai_tagger import PixAITaggerInference
from thop import profile, clever_format
@@ -17,6 +24,7 @@ from torch import nn
from imgutils.preprocess import parse_torchvision_transforms
from test.testings import get_testfile
from zoo.utils import onnx_optimize
from zoo.wd14.tags import _get_tag_by_name


class ModuleWrapper(nn.Module):
@@ -53,7 +61,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"):
    model: PixAITaggerInference = get_model("pixai_tagger", model_version=model_name, device='cpu')
    infer_model = model.model
    transforms = model.transform
    return infer_model, transforms
    return model, infer_model, transforms


def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bool = False):
@@ -62,7 +70,8 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo

    os.makedirs(export_dir, exist_ok=True)

    model, transforms = load_model(model_name)
    raw_model, model, transforms = load_model(model_name)
    raw_model: PixAITaggerInference
    image = Image.open(get_testfile('genshin_post.jpg'))
    dummy_input = transforms(image).unsqueeze(0)
    logging.info(f'Dummy input size: {dummy_input.shape!r}')
@@ -99,11 +108,75 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo
    s_macs, s_params = clever_format([macs, params], "%.1f")
    logging.info(f'Params: {s_params}, FLOPs: {s_macs}')

    src_repo_id = raw_model.model_version_map[model_name]['repo_id']
    src_model_filename = raw_model.model_version_map[model_name]['ckpt_name']

    with open(os.path.join(export_dir, 'meta.json'), 'w') as f:
        json.dump({
            'num_classes': conv_preds.shape[-1],
            'num_features': conv_features.shape[-1],
            'params': params,
            'flops': macs,
            'name': model_name,
            'model_cls': type(model).__name__,
            'input_size': dummy_input.shape[2],
            'repo_id': src_repo_id,
            'model_filename': src_model_filename,
            'created_at': hf_client.get_paths_info(
                repo_id=src_repo_id,
                repo_type='model',
                paths=[src_model_filename],
                expand=True
            )[0].last_commit.date.timestamp(),
        }, f, indent=4, sort_keys=True)

    logging.info(f'Writing transforms:\n{transforms}')
    with open(os.path.join(export_dir, 'preprocess.json'), 'w') as f:
        json.dump({
            'stages': parse_torchvision_transforms(transforms),
        }, f, indent=4, sort_keys=True)

    df_p_tags = pd.read_csv(hf_client.hf_hub_download(
        repo_id='deepghs/site_tags',
        repo_type='dataset',
        filename='danbooru.donmai.us/tags.csv'
    ))
    logging.info(f'Loaded danbooru tags pool, columns: {df_p_tags.columns!r}')
    d_p_tags = {(item['category'], item['name']): item for item in df_p_tags.to_dict('records')}

    num_classes = raw_model.model_version_map[model_name]['num_classes']
    logging.info(f'Num classes: {num_classes!r}')
    d_tags = {v: k for k, v in raw_model.tag_map.items()}
    r_tags = []
    for i in range(num_classes):
        category = 0 if i < raw_model.gen_tag_count else 4
        if (category, d_tags[i]) in d_p_tags:
            tag_id = d_p_tags[(category, d_tags[i])]['id']
            count = d_p_tags[(category, d_tags[i])]['post_count']
        else:
            logging.warning(f'Cannot find tag {d_tags[i]!r}, category: {category!r}.')
            tag_info = _get_tag_by_name(d_tags[i])
            if tag_info['name'] != d_tags[i]:
                logging.warning(f'Not found matching tags for {d_tags[i]!r}, will be ignored.')
                tag_id = -1
                count = -1
            else:
                logging.info(f'Tag info found from danbooru - {tag_info!r}.')
                tag_id = tag_info['id']
                count = tag_info['post_count']
        r_tags.append({
            'id': i,
            'tag_id': tag_id,
            'name': d_tags[i],
            'category': category,
            'count': count,
        })
    df_tags = pd.DataFrame(r_tags)
    tags_file = os.path.join(export_dir, 'selected_tags.csv')
    logging.info(f'Tags List:\n{df_tags}\n'
                 f'Saving to {tags_file!r} ...')
    df_tags.to_csv(tags_file, index=False)

    onnx_filename = os.path.join(export_dir, 'model.onnx')
    with TemporaryDirectory() as td:
        temp_model_onnx = os.path.join(td, 'model.onnx')
@@ -147,8 +220,107 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo
        assert emb_sims >= 0.98, f'Similarity of the embeddings is {emb_sims:.5f}, ONNX validation failed.'


def sync(repository: str = 'onopix/pixai-tagger-onnx'):
    """
    Export the tagger model(s) to ONNX and upload the results to a Hugging Face model repository.

    The function performs the following steps:

    1. Creates ``repository`` as a private model repo when it does not exist yet, and
       appends LFS rules for ``*.json`` and ``*.csv`` to its ``.gitattributes``.
    2. Loads the existing ``models.parquet`` index from the repository (if present), keyed by
       model name.
    3. For every model name in the export list, runs :func:`extract` into a temporary
       directory, merges the generated ``meta.json`` into the index, regenerates
       ``models.parquet`` and ``README.md``, and uploads the whole directory.

    A model whose export raises is logged and skipped; it does not abort the remaining models.

    :param repository: Target Hugging Face model repository id.
    """
    hf_client = get_hf_client()
    hf_fs = get_hf_fs()
    delete_detached_cache()
    if not hf_client.repo_exists(repo_id=repository, repo_type='model'):
        hf_client.create_repo(repo_id=repository, repo_type='model', private=True)
        attr_lines = hf_fs.read_text(f'{repository}/.gitattributes').splitlines(keepends=False)
        attr_lines.append('*.json filter=lfs diff=lfs merge=lfs -text')
        attr_lines.append('*.csv filter=lfs diff=lfs merge=lfs -text')
        # Join with '\n' rather than os.linesep so the uploaded file does not pick up
        # platform-specific (CRLF) line endings.
        hf_fs.write_text(f'{repository}/.gitattributes', '\n'.join(attr_lines))

    # Load the already-published model index so that previously exported models are kept.
    if hf_client.file_exists(
            repo_id=repository,
            repo_type='model',
            filename='models.parquet',
    ):
        df_models = pd.read_parquet(hf_client.hf_hub_download(
            repo_id=repository,
            repo_type='model',
            filename='models.parquet',
        ))
        d_models = {item['name']: item for item in df_models.to_dict('records')}
    else:
        d_models = {}

    for model_name in ["tagger_v_2_2_7"]:
        with TemporaryDirectory() as upload_dir:
            logging.info(f'Exporting model {model_name!r} ...')
            model_dir = os.path.join(upload_dir, model_name)
            os.makedirs(model_dir, exist_ok=True)
            try:
                extract(
                    export_dir=model_dir,
                    # Export the model this iteration is about, not extract()'s default.
                    model_name=model_name,
                    no_optimize=False,
                )
            except Exception:
                logging.exception(f'Error when exporting {model_name!r}, skipped.')
                continue

            with open(os.path.join(model_dir, 'meta.json'), 'r') as f:
                meta_info = json.load(f)
            # Insert or overwrite this model's entry in the index.
            d_models[meta_info['name']] = copy.deepcopy(meta_info)

            df_models = pd.DataFrame(list(d_models.values()))
            df_models = df_models.sort_values(by=['created_at'], ascending=False)
            df_models.to_parquet(os.path.join(upload_dir, 'models.parquet'), index=False)

            with open(os.path.join(upload_dir, 'README.md'), 'w') as f:
                # YAML front matter of the model card.
                print('---', file=f)
                print('pipeline_tag: image-classification', file=f)
                print('base_model:', file=f)
                # Only the first 100 rows' source repositories are listed as base models.
                for rid in natsorted(set(df_models['repo_id'][:100])):
                    print(f'- {rid}', file=f)
                print('language:', file=f)
                print('- en', file=f)
                print('tags:', file=f)
                print('- timm', file=f)
                print('- image', file=f)
                print('- dghs-imgutils', file=f)
                print('library_name: dghs-imgutils', file=f)
                print('---', file=f)
                print('', file=f)

                print('ONNX export version from [TIMM](https://huggingface.co/timm).', file=f)
                print('', file=f)

                print('# Models', file=f)
                print('', file=f)

                df_shown = pd.DataFrame([
                    {
                        "Name": f'[{item["name"]}]({hf_hub_repo_file_url(repo_id=item["repo_id"], repo_type="model", path=item["model_filename"])})',
                        'Params': clever_format(item["params"], "%.1f"),
                        'Flops': clever_format(item["flops"], "%.1f"),
                        'Input Size': item['input_size'],
                        "Features": item['num_features'],
                        "Classes": item['num_classes'],
                        'Model': item['model_cls'],
                        # NOTE: fromtimestamp() without tz yields naive local time, so the
                        # rendered date depends on the timezone of the exporting machine.
                        'Created At': datetime.datetime.fromtimestamp(item['created_at']).strftime('%Y-%m-%d'),
                        # Raw timestamp is kept only as the sort key and dropped before rendering.
                        'created_at': item['created_at'],
                    }
                    for item in df_models.to_dict('records')
                ])
                df_shown = df_shown.sort_values(by=['created_at'], ascending=[False])
                del df_shown['created_at']
                print(f'{plural_word(len(df_shown), "ONNX model")} exported in total.', file=f)
                print('', file=f)
                print(df_shown.to_markdown(index=False), file=f)
                print('', file=f)

            upload_directory_as_directory(
                repo_id=repository,
                repo_type='model',
                local_directory=upload_dir,
                path_in_repo='.',
                message=f'Export model {model_name!r}',
            )


if __name__ == '__main__':
    logging.try_init_root(level=logging.INFO)
    extract(
        export_dir='test_ex',
    sync(
        repository='onopix/pixai-tagger-onnx'
    )