Loading zoo/ptagger/model.py +177 −5 Original line number Diff line number Diff line import copy import datetime import json import os.path import numpy as np import onnx import onnxruntime import pandas as pd import torch from PIL import Image from ditk import logging from hbutils.string import plural_word from hbutils.system import TemporaryDirectory from hfutils.operate import get_hf_fs, get_hf_client from hfutils.cache import delete_detached_cache from hfutils.operate import get_hf_fs, get_hf_client, upload_directory_as_directory from hfutils.repository import hf_hub_repo_file_url from natsort import natsorted from procslib import get_model from procslib.models.pixai_tagger import PixAITaggerInference from thop import profile, clever_format Loading @@ -17,6 +24,7 @@ from torch import nn from imgutils.preprocess import parse_torchvision_transforms from test.testings import get_testfile from zoo.utils import onnx_optimize from zoo.wd14.tags import _get_tag_by_name class ModuleWrapper(nn.Module): Loading Loading @@ -53,7 +61,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"): model: PixAITaggerInference = get_model("pixai_tagger", model_version=model_name, device='cpu') infer_model = model.model transforms = model.transform return infer_model, transforms return model, infer_model, transforms def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bool = False): Loading @@ -62,7 +70,8 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo os.makedirs(export_dir, exist_ok=True) model, transforms = load_model(model_name) raw_model, model, transforms = load_model(model_name) raw_model: PixAITaggerInference image = Image.open(get_testfile('genshin_post.jpg')) dummy_input = transforms(image).unsqueeze(0) logging.info(f'Dummy input size: {dummy_input.shape!r}') Loading Loading @@ -99,11 +108,75 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo s_macs, s_params = clever_format([macs, params], "%.1f") logging.info(f'Params: {s_params}, FLOPs: {s_macs}') src_repo_id = raw_model.model_version_map[model_name]['repo_id'] src_model_filename = raw_model.model_version_map[model_name]['ckpt_name'] with open(os.path.join(export_dir, 'meta.json'), 'w') as f: json.dump({ 'num_classes': conv_preds.shape[-1], 'num_features': conv_features.shape[-1], 'params': params, 'flops': macs, 'name': model_name, 'model_cls': type(model).__name__, 'input_size': dummy_input.shape[2], 'repo_id': src_repo_id, 'model_filename': src_model_filename, 'created_at': hf_client.get_paths_info( repo_id=src_repo_id, repo_type='model', paths=[src_model_filename], expand=True )[0].last_commit.date.timestamp(), }, f, indent=4, sort_keys=True) logging.info(f'Writing transforms:\n{transforms}') with open(os.path.join(export_dir, 'preprocess.json'), 'w') as f: json.dump({ 'stages': parse_torchvision_transforms(transforms), }, f, indent=4, sort_keys=True) df_p_tags = pd.read_csv(hf_client.hf_hub_download( repo_id='deepghs/site_tags', repo_type='dataset', filename='danbooru.donmai.us/tags.csv' )) logging.info(f'Loaded danbooru tags pool, columns: {df_p_tags.columns!r}') d_p_tags = {(item['category'], item['name']): item for item in df_p_tags.to_dict('records')} num_classes = raw_model.model_version_map[model_name]['num_classes'] logging.info(f'Num classes: {num_classes!r}') d_tags = {v: k for k, v in raw_model.tag_map.items()} r_tags = [] for i in range(num_classes): category = 0 if i < raw_model.gen_tag_count else 4 if (category, d_tags[i]) in d_p_tags: tag_id = d_p_tags[(category, d_tags[i])]['id'] count = d_p_tags[(category, d_tags[i])]['post_count'] else: logging.warning(f'Cannot find tag {d_tags[i]!r}, category: {category!r}.') tag_info = _get_tag_by_name(d_tags[i]) if tag_info['name'] != d_tags[i]: logging.warning(f'Not found matching tags for {d_tags[i]!r}, will be ignored.') tag_id = -1 count = -1 else: logging.info(f'Tag info found from danbooru - {tag_info!r}.') tag_id = tag_info['id'] count = tag_info['post_count'] r_tags.append({ 'id': i, 'tag_id': tag_id, 'name': d_tags[i], 'category': category, 'count': count, }) df_tags = pd.DataFrame(r_tags) tags_file = os.path.join(export_dir, 'selected_tags.csv') logging.info(f'Tags List:\n{df_tags}\n' f'Saving to {tags_file!r} ...') df_tags.to_csv(tags_file, index=False) onnx_filename = os.path.join(export_dir, 'model.onnx') with TemporaryDirectory() as td: temp_model_onnx = os.path.join(td, 'model.onnx') Loading Loading @@ -147,8 +220,107 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo assert emb_sims >= 0.98, f'Similarity of the embeddings is {emb_sims:.5f}, ONNX validation failed.' def sync(repository: str = 'onopix/pixai-tagger-onnx'): hf_client = get_hf_client() hf_fs = get_hf_fs() delete_detached_cache() if not hf_client.repo_exists(repo_id=repository, repo_type='model'): hf_client.create_repo(repo_id=repository, repo_type='model', private=True) attr_lines = hf_fs.read_text(f'{repository}/.gitattributes').splitlines(keepends=False) attr_lines.append('*.json filter=lfs diff=lfs merge=lfs -text') attr_lines.append('*.csv filter=lfs diff=lfs merge=lfs -text') hf_fs.write_text(f'{repository}/.gitattributes', os.linesep.join(attr_lines)) if hf_client.file_exists( repo_id=repository, repo_type='model', filename='models.parquet', ): df_models = pd.read_parquet(hf_client.hf_hub_download( repo_id=repository, repo_type='model', filename='models.parquet', )) d_models = {item['name']: item for item in df_models.to_dict('records')} else: d_models = {} for model_name in ["tagger_v_2_2_7"]: with TemporaryDirectory() as upload_dir: logging.info(f'Exporting model {model_name!r} ...') os.makedirs(os.path.join(upload_dir, model_name), exist_ok=True) try: extract( export_dir=os.path.join(upload_dir, model_name), no_optimize=False, ) except Exception: logging.exception(f'Error when exporting {model_name!r}, skipped.') continue with open(os.path.join(upload_dir, model_name, 'meta.json'), 'r') as f: meta_info = json.load(f) c_meta_info = copy.deepcopy(meta_info) d_models[meta_info['name']] = c_meta_info df_models = pd.DataFrame(list(d_models.values())) df_models = df_models.sort_values(by=['created_at'], ascending=False) df_models.to_parquet(os.path.join(upload_dir, 'models.parquet'), index=False) with open(os.path.join(upload_dir, 'README.md'), 'w') as f: print('---', file=f) print('pipeline_tag: image-classification', file=f) print('base_model:', file=f) for rid in natsorted(set(df_models['repo_id'][:100])): print(f'- {rid}', file=f) print('language:', file=f) print('- en', file=f) print('tags:', file=f) print('- timm', file=f) print('- image', file=f) print('- dghs-imgutils', file=f) print('library_name: dghs-imgutils', file=f) print('---', file=f) print('', file=f) print('ONNX export version from [TIMM](https://huggingface.co/timm).', file=f) print('', file=f) print(f'# Models', file=f) print(f'', file=f) df_shown = pd.DataFrame([ { "Name": f'[{item["name"]}]({hf_hub_repo_file_url(repo_id=item["repo_id"], repo_type="model", path=item["model_filename"])})', 'Params': clever_format(item["params"], "%.1f"), 'Flops': clever_format(item["flops"], "%.1f"), 'Input Size': item['input_size'], "Features": item['num_features'], "Classes": item['num_classes'], 'Model': item['model_cls'], 'Created At': datetime.datetime.fromtimestamp(item['created_at']).strftime('%Y-%m-%d'), 'created_at': item['created_at'], } for item in df_models.to_dict('records') ]) df_shown = df_shown.sort_values(by=['created_at'], ascending=[False]) del df_shown['created_at'] print(f'{plural_word(len(df_shown), "ONNX model")} exported in total.', file=f) print(f'', file=f) print(df_shown.to_markdown(index=False), file=f) print(f'', file=f) upload_directory_as_directory( repo_id=repository, repo_type='model', local_directory=upload_dir, path_in_repo='.', message=f'Export model {model_name!r}', ) if __name__ == '__main__': logging.try_init_root(level=logging.INFO) extract( export_dir='test_ex', sync( repository='onopix/pixai-tagger-onnx' ) Loading
zoo/ptagger/model.py +177 −5 Original line number Diff line number Diff line import copy import datetime import json import os.path import numpy as np import onnx import onnxruntime import pandas as pd import torch from PIL import Image from ditk import logging from hbutils.string import plural_word from hbutils.system import TemporaryDirectory from hfutils.operate import get_hf_fs, get_hf_client from hfutils.cache import delete_detached_cache from hfutils.operate import get_hf_fs, get_hf_client, upload_directory_as_directory from hfutils.repository import hf_hub_repo_file_url from natsort import natsorted from procslib import get_model from procslib.models.pixai_tagger import PixAITaggerInference from thop import profile, clever_format Loading @@ -17,6 +24,7 @@ from torch import nn from imgutils.preprocess import parse_torchvision_transforms from test.testings import get_testfile from zoo.utils import onnx_optimize from zoo.wd14.tags import _get_tag_by_name class ModuleWrapper(nn.Module): Loading Loading @@ -53,7 +61,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"): model: PixAITaggerInference = get_model("pixai_tagger", model_version=model_name, device='cpu') infer_model = model.model transforms = model.transform return infer_model, transforms return model, infer_model, transforms def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bool = False): Loading @@ -62,7 +70,8 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo os.makedirs(export_dir, exist_ok=True) model, transforms = load_model(model_name) raw_model, model, transforms = load_model(model_name) raw_model: PixAITaggerInference image = Image.open(get_testfile('genshin_post.jpg')) dummy_input = transforms(image).unsqueeze(0) logging.info(f'Dummy input size: {dummy_input.shape!r}') Loading Loading @@ -99,11 +108,75 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo s_macs, s_params = clever_format([macs, params], "%.1f") logging.info(f'Params: {s_params}, FLOPs: {s_macs}') src_repo_id = raw_model.model_version_map[model_name]['repo_id'] src_model_filename = raw_model.model_version_map[model_name]['ckpt_name'] with open(os.path.join(export_dir, 'meta.json'), 'w') as f: json.dump({ 'num_classes': conv_preds.shape[-1], 'num_features': conv_features.shape[-1], 'params': params, 'flops': macs, 'name': model_name, 'model_cls': type(model).__name__, 'input_size': dummy_input.shape[2], 'repo_id': src_repo_id, 'model_filename': src_model_filename, 'created_at': hf_client.get_paths_info( repo_id=src_repo_id, repo_type='model', paths=[src_model_filename], expand=True )[0].last_commit.date.timestamp(), }, f, indent=4, sort_keys=True) logging.info(f'Writing transforms:\n{transforms}') with open(os.path.join(export_dir, 'preprocess.json'), 'w') as f: json.dump({ 'stages': parse_torchvision_transforms(transforms), }, f, indent=4, sort_keys=True) df_p_tags = pd.read_csv(hf_client.hf_hub_download( repo_id='deepghs/site_tags', repo_type='dataset', filename='danbooru.donmai.us/tags.csv' )) logging.info(f'Loaded danbooru tags pool, columns: {df_p_tags.columns!r}') d_p_tags = {(item['category'], item['name']): item for item in df_p_tags.to_dict('records')} num_classes = raw_model.model_version_map[model_name]['num_classes'] logging.info(f'Num classes: {num_classes!r}') d_tags = {v: k for k, v in raw_model.tag_map.items()} r_tags = [] for i in range(num_classes): category = 0 if i < raw_model.gen_tag_count else 4 if (category, d_tags[i]) in d_p_tags: tag_id = d_p_tags[(category, d_tags[i])]['id'] count = d_p_tags[(category, d_tags[i])]['post_count'] else: logging.warning(f'Cannot find tag {d_tags[i]!r}, category: {category!r}.') tag_info = _get_tag_by_name(d_tags[i]) if tag_info['name'] != d_tags[i]: logging.warning(f'Not found matching tags for {d_tags[i]!r}, will be ignored.') tag_id = -1 count = -1 else: logging.info(f'Tag info found from danbooru - {tag_info!r}.') tag_id = tag_info['id'] count = tag_info['post_count'] r_tags.append({ 'id': i, 'tag_id': tag_id, 'name': d_tags[i], 'category': category, 'count': count, }) df_tags = pd.DataFrame(r_tags) tags_file = os.path.join(export_dir, 'selected_tags.csv') logging.info(f'Tags List:\n{df_tags}\n' f'Saving to {tags_file!r} ...') df_tags.to_csv(tags_file, index=False) onnx_filename = os.path.join(export_dir, 'model.onnx') with TemporaryDirectory() as td: temp_model_onnx = os.path.join(td, 'model.onnx') Loading Loading @@ -147,8 +220,107 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo assert emb_sims >= 0.98, f'Similarity of the embeddings is {emb_sims:.5f}, ONNX validation failed.' def sync(repository: str = 'onopix/pixai-tagger-onnx'): hf_client = get_hf_client() hf_fs = get_hf_fs() delete_detached_cache() if not hf_client.repo_exists(repo_id=repository, repo_type='model'): hf_client.create_repo(repo_id=repository, repo_type='model', private=True) attr_lines = hf_fs.read_text(f'{repository}/.gitattributes').splitlines(keepends=False) attr_lines.append('*.json filter=lfs diff=lfs merge=lfs -text') attr_lines.append('*.csv filter=lfs diff=lfs merge=lfs -text') hf_fs.write_text(f'{repository}/.gitattributes', os.linesep.join(attr_lines)) if hf_client.file_exists( repo_id=repository, repo_type='model', filename='models.parquet', ): df_models = pd.read_parquet(hf_client.hf_hub_download( repo_id=repository, repo_type='model', filename='models.parquet', )) d_models = {item['name']: item for item in df_models.to_dict('records')} else: d_models = {} for model_name in ["tagger_v_2_2_7"]: with TemporaryDirectory() as upload_dir: logging.info(f'Exporting model {model_name!r} ...') os.makedirs(os.path.join(upload_dir, model_name), exist_ok=True) try: extract( export_dir=os.path.join(upload_dir, model_name), no_optimize=False, ) except Exception: logging.exception(f'Error when exporting {model_name!r}, skipped.') continue with open(os.path.join(upload_dir, model_name, 'meta.json'), 'r') as f: meta_info = json.load(f) c_meta_info = copy.deepcopy(meta_info) d_models[meta_info['name']] = c_meta_info df_models = pd.DataFrame(list(d_models.values())) df_models = df_models.sort_values(by=['created_at'], ascending=False) df_models.to_parquet(os.path.join(upload_dir, 'models.parquet'), index=False) with open(os.path.join(upload_dir, 'README.md'), 'w') as f: print('---', file=f) print('pipeline_tag: image-classification', file=f) print('base_model:', file=f) for rid in natsorted(set(df_models['repo_id'][:100])): print(f'- {rid}', file=f) print('language:', file=f) print('- en', file=f) print('tags:', file=f) print('- timm', file=f) print('- image', file=f) print('- dghs-imgutils', file=f) print('library_name: dghs-imgutils', file=f) print('---', file=f) print('', file=f) print('ONNX export version from [TIMM](https://huggingface.co/timm).', file=f) print('', file=f) print(f'# Models', file=f) print(f'', file=f) df_shown = pd.DataFrame([ { "Name": f'[{item["name"]}]({hf_hub_repo_file_url(repo_id=item["repo_id"], repo_type="model", path=item["model_filename"])})', 'Params': clever_format(item["params"], "%.1f"), 'Flops': clever_format(item["flops"], "%.1f"), 'Input Size': item['input_size'], "Features": item['num_features'], "Classes": item['num_classes'], 'Model': item['model_cls'], 'Created At': datetime.datetime.fromtimestamp(item['created_at']).strftime('%Y-%m-%d'), 'created_at': item['created_at'], } for item in df_models.to_dict('records') ]) df_shown = df_shown.sort_values(by=['created_at'], ascending=[False]) del df_shown['created_at'] print(f'{plural_word(len(df_shown), "ONNX model")} exported in total.', file=f) print(f'', file=f) print(df_shown.to_markdown(index=False), file=f) print(f'', file=f) upload_directory_as_directory( repo_id=repository, repo_type='model', local_directory=upload_dir, path_in_repo='.', message=f'Export model {model_name!r}', ) if __name__ == '__main__': logging.try_init_root(level=logging.INFO) extract( export_dir='test_ex', sync( repository='onopix/pixai-tagger-onnx' )