Commit ed681c63 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): update them, ci skip

parent 69617993
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@ from natsort import natsorted
from procslib import get_model
from procslib.models.pixai_tagger import PixAITaggerInference
from thop import profile, clever_format
from timm.models._hub import save_for_hf
from torch import nn

from imgutils.preprocess import parse_torchvision_transforms
@@ -163,6 +164,13 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo
    s_macs, s_params = clever_format([macs, params], "%.1f")
    logging.info(f'Params: {s_params}, FLOPs: {s_macs}')

    logging.info('Exporting model weights ...')
    save_for_hf(
        model,
        expected_logits,
        safe_serialization='both',
    )

    with open(os.path.join(export_dir, 'meta.json'), 'w') as f:
        json.dump({
            'num_classes': conv_preds.shape[-1],