Commit 0531ed03 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): update them, ci skip

parent 45f9473e
Loading
Loading
Loading
Loading
+9 −8
Original line number Diff line number Diff line
@@ -68,6 +68,8 @@ def load_model(model_name: str = "tagger_v_2_2_7"):
            paths=[model.model_version_map[model_name]['ckpt_name']],
            expand=True
        )[0].last_commit.date.timestamp()
        model_repo_id = model.model_version_map[model_name]['repo_id']
        model_file = model.model_version_map[model_name]['ckpt_name']

    except (KeyError, ValueError):
        alt_model_name = "tagger_v_2_2_7"
@@ -78,10 +80,12 @@ def load_model(model_name: str = "tagger_v_2_2_7"):
            repo_type='model',
            filename=model.model_version_map[alt_model_name]['ckpt_name'],
        ), map_location="cpu")
        model_repo_id = model.model_version_map[model_name]['repo_id']
        model_file = f'{model_name}.pth'
        state_dicts_head = torch.load(hf_client.hf_hub_download(
            repo_id=model.model_version_map[alt_model_name]['repo_id'],
            repo_type='model',
            filename=f'{model_name}.pth',
            filename=model_file,
        ), map_location="cpu")
        state_dicts['head.weight'] = state_dicts_head['head.0.weight']
        state_dicts['head.bias'] = state_dicts_head['head.0.bias']
@@ -99,7 +103,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"):

    infer_model = model.model
    transforms = model.transform
    return model, infer_model, transforms, created_at
    return model, infer_model, transforms, (model_repo_id, model_file), created_at


def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bool = False):
@@ -108,7 +112,7 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo

    os.makedirs(export_dir, exist_ok=True)

    raw_model, model, transforms, created_at = load_model(model_name)
    raw_model, model, transforms, (model_repo_id, model_filename), created_at = load_model(model_name)
    raw_model: PixAITaggerInference
    image = Image.open(get_testfile('genshin_post.jpg'))
    dummy_input = transforms(image).unsqueeze(0)
@@ -159,9 +163,6 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo
    s_macs, s_params = clever_format([macs, params], "%.1f")
    logging.info(f'Params: {s_params}, FLOPs: {s_macs}')

    src_repo_id = raw_model.model_version_map[raw_model.model_version]['repo_id']
    src_model_filename = raw_model.model_version_map[raw_model.model_version]['ckpt_name']

    with open(os.path.join(export_dir, 'meta.json'), 'w') as f:
        json.dump({
            'num_classes': conv_preds.shape[-1],
@@ -171,8 +172,8 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo
            'name': model_name,
            'model_cls': type(model).__name__,
            'input_size': dummy_input.shape[2],
            'repo_id': src_repo_id,
            'model_filename': src_model_filename,
            'repo_id': model_repo_id,
            'model_filename': model_filename,
            'created_at': created_at,
        }, f, indent=4, sort_keys=True)