Loading zoo/ptagger/model.py +9 −8 Original line number Diff line number Diff line Loading @@ -68,6 +68,8 @@ def load_model(model_name: str = "tagger_v_2_2_7"): paths=[model.model_version_map[model_name]['ckpt_name']], expand=True )[0].last_commit.date.timestamp() model_repo_id = model.model_version_map[model_name]['repo_id'] model_file = model.model_version_map[model_name]['ckpt_name'] except (KeyError, ValueError): alt_model_name = "tagger_v_2_2_7" Loading @@ -78,10 +80,12 @@ def load_model(model_name: str = "tagger_v_2_2_7"): repo_type='model', filename=model.model_version_map[alt_model_name]['ckpt_name'], ), map_location="cpu") model_repo_id = model.model_version_map[model_name]['repo_id'] model_file = f'{model_name}.pth' state_dicts_head = torch.load(hf_client.hf_hub_download( repo_id=model.model_version_map[alt_model_name]['repo_id'], repo_type='model', filename=f'{model_name}.pth', filename=model_file, ), map_location="cpu") state_dicts['head.weight'] = state_dicts_head['head.0.weight'] state_dicts['head.bias'] = state_dicts_head['head.0.bias'] Loading @@ -99,7 +103,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"): infer_model = model.model transforms = model.transform return model, infer_model, transforms, created_at return model, infer_model, transforms, (model_repo_id, model_file), created_at def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bool = False): Loading @@ -108,7 +112,7 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo os.makedirs(export_dir, exist_ok=True) raw_model, model, transforms, created_at = load_model(model_name) raw_model, model, transforms, (model_repo_id, model_filename), created_at = load_model(model_name) raw_model: PixAITaggerInference image = Image.open(get_testfile('genshin_post.jpg')) dummy_input = transforms(image).unsqueeze(0) Loading Loading @@ -159,9 +163,6 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo s_macs, s_params = clever_format([macs, params], "%.1f") logging.info(f'Params: {s_params}, FLOPs: {s_macs}') src_repo_id = raw_model.model_version_map[raw_model.model_version]['repo_id'] src_model_filename = raw_model.model_version_map[raw_model.model_version]['ckpt_name'] with open(os.path.join(export_dir, 'meta.json'), 'w') as f: json.dump({ 'num_classes': conv_preds.shape[-1], Loading @@ -171,8 +172,8 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo 'name': model_name, 'model_cls': type(model).__name__, 'input_size': dummy_input.shape[2], 'repo_id': src_repo_id, 'model_filename': src_model_filename, 'repo_id': model_repo_id, 'model_filename': model_filename, 'created_at': created_at, }, f, indent=4, sort_keys=True) Loading Loading
zoo/ptagger/model.py +9 −8 Original line number Diff line number Diff line Loading @@ -68,6 +68,8 @@ def load_model(model_name: str = "tagger_v_2_2_7"): paths=[model.model_version_map[model_name]['ckpt_name']], expand=True )[0].last_commit.date.timestamp() model_repo_id = model.model_version_map[model_name]['repo_id'] model_file = model.model_version_map[model_name]['ckpt_name'] except (KeyError, ValueError): alt_model_name = "tagger_v_2_2_7" Loading @@ -78,10 +80,12 @@ def load_model(model_name: str = "tagger_v_2_2_7"): repo_type='model', filename=model.model_version_map[alt_model_name]['ckpt_name'], ), map_location="cpu") model_repo_id = model.model_version_map[model_name]['repo_id'] model_file = f'{model_name}.pth' state_dicts_head = torch.load(hf_client.hf_hub_download( repo_id=model.model_version_map[alt_model_name]['repo_id'], repo_type='model', filename=f'{model_name}.pth', filename=model_file, ), map_location="cpu") state_dicts['head.weight'] = state_dicts_head['head.0.weight'] state_dicts['head.bias'] = state_dicts_head['head.0.bias'] Loading @@ -99,7 +103,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"): infer_model = model.model transforms = model.transform return model, infer_model, transforms, created_at return model, infer_model, transforms, (model_repo_id, model_file), created_at def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bool = False): Loading @@ -108,7 +112,7 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo os.makedirs(export_dir, exist_ok=True) raw_model, model, transforms, created_at = load_model(model_name) raw_model, model, transforms, (model_repo_id, model_filename), created_at = load_model(model_name) raw_model: PixAITaggerInference image = Image.open(get_testfile('genshin_post.jpg')) dummy_input = transforms(image).unsqueeze(0) Loading Loading @@ -159,9 +163,6 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo s_macs, s_params = clever_format([macs, params], "%.1f") logging.info(f'Params: {s_params}, FLOPs: {s_macs}') src_repo_id = raw_model.model_version_map[raw_model.model_version]['repo_id'] src_model_filename = raw_model.model_version_map[raw_model.model_version]['ckpt_name'] with open(os.path.join(export_dir, 'meta.json'), 'w') as f: json.dump({ 'num_classes': conv_preds.shape[-1], Loading @@ -171,8 +172,8 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo 'name': model_name, 'model_cls': type(model).__name__, 'input_size': dummy_input.shape[2], 'repo_id': src_repo_id, 'model_filename': src_model_filename, 'repo_id': model_repo_id, 'model_filename': model_filename, 'created_at': created_at, }, f, indent=4, sort_keys=True) Loading