Commit 45f9473e authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add new export code, ci skip

parent 58865938
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -159,8 +159,8 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo
    s_macs, s_params = clever_format([macs, params], "%.1f")
    logging.info(f'Params: {s_params}, FLOPs: {s_macs}')

    src_repo_id = raw_model.model_version_map[model_name]['repo_id']
    src_model_filename = raw_model.model_version_map[model_name]['ckpt_name']
    src_repo_id = raw_model.model_version_map[raw_model.model_version]['repo_id']
    src_model_filename = raw_model.model_version_map[raw_model.model_version]['ckpt_name']

    with open(os.path.join(export_dir, 'meta.json'), 'w') as f:
        json.dump({
@@ -190,7 +190,7 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo
    logging.info(f'Loaded danbooru tags pool, columns: {df_p_tags.columns!r}')
    d_p_tags = {(item['category'], item['name']): item for item in df_p_tags.to_dict('records')}

    num_classes = raw_model.model_version_map[model_name]['num_classes']
    num_classes = raw_model.model_version_map[raw_model.model_version]['num_classes']
    logging.info(f'Num classes: {num_classes!r}')
    d_tags = {v: k for k, v in raw_model.tag_map.items()}
    r_tags = []