Loading zoo/ptagger/model.py +39 −10 Original line number Diff line number Diff line Loading @@ -58,10 +58,44 @@ class ModuleWrapper(nn.Module): def load_model(model_name: str = "tagger_v_2_2_7"): hf_client = get_hf_client() try: model: PixAITaggerInference = get_model("pixai_tagger", model_version=model_name, device='cpu') created_at = hf_client.get_paths_info( repo_id=model.model_version_map[model_name]['repo_id'], repo_type='model', paths=[model.model_version_map[model_name]['ckpt_name']], expand=True )[0].last_commit.date.timestamp() except KeyError: model: PixAITaggerInference = get_model("pixai_tagger", model_version='tagger_v_2_2_7', device='cpu') state_dicts = torch.load(hf_client.hf_hub_download( repo_id=model.model_version_map[model_name]['repo_id'], repo_type='model', filename=model.model_version_map[model_name]['ckpt_name'], ), map_location="cpu") state_dicts_head = torch.load(hf_client.hf_hub_download( repo_id=model.model_version_map[model_name], repo_type='model', filename=f'{model_name}.pth', ), map_location="cpu") state_dicts['head.weight'] = state_dicts_head['head.0.weight'] state_dicts['head.bias'] = state_dicts_head['head.0.bias'] model.model.load_state_dict(state_dicts) model.model = model.model.to(model.device) model.model.eval() created_at = hf_client.get_paths_info( repo_id=model.model_version_map[model_name]['repo_id'], repo_type='model', paths=[f'{model_name}.pth'], expand=True )[0].last_commit.date.timestamp() infer_model = model.model transforms = model.transform return model, infer_model, transforms return model, infer_model, transforms, created_at def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bool = False): Loading @@ -70,7 +104,7 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo os.makedirs(export_dir, exist_ok=True) raw_model, model, transforms = load_model(model_name) raw_model, model, transforms, created_at = load_model(model_name) raw_model: PixAITaggerInference image = Image.open(get_testfile('genshin_post.jpg')) dummy_input = transforms(image).unsqueeze(0) Loading Loading @@ -135,12 +169,7 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo 'input_size': dummy_input.shape[2], 'repo_id': src_repo_id, 'model_filename': src_model_filename, 'created_at': hf_client.get_paths_info( repo_id=src_repo_id, repo_type='model', paths=[src_model_filename], expand=True )[0].last_commit.date.timestamp(), 'created_at': created_at, }, f, indent=4, sort_keys=True) logging.info(f'Writing transforms:\n{transforms}') Loading Loading @@ -258,7 +287,7 @@ def sync(repository: str = 'onopix/pixai-tagger-onnx'): else: d_models = {} for model_name in ["tagger_v_2_2_7"]: for model_name in ["tagger_v_2_3_2", "tagger_v_2_2_7"]: with TemporaryDirectory() as upload_dir: logging.info(f'Exporting model {model_name!r} ...') os.makedirs(os.path.join(upload_dir, model_name), exist_ok=True) Loading Loading
zoo/ptagger/model.py +39 −10 Original line number Diff line number Diff line Loading @@ -58,10 +58,44 @@ class ModuleWrapper(nn.Module): def load_model(model_name: str = "tagger_v_2_2_7"): hf_client = get_hf_client() try: model: PixAITaggerInference = get_model("pixai_tagger", model_version=model_name, device='cpu') created_at = hf_client.get_paths_info( repo_id=model.model_version_map[model_name]['repo_id'], repo_type='model', paths=[model.model_version_map[model_name]['ckpt_name']], expand=True )[0].last_commit.date.timestamp() except KeyError: model: PixAITaggerInference = get_model("pixai_tagger", model_version='tagger_v_2_2_7', device='cpu') state_dicts = torch.load(hf_client.hf_hub_download( repo_id=model.model_version_map[model_name]['repo_id'], repo_type='model', filename=model.model_version_map[model_name]['ckpt_name'], ), map_location="cpu") state_dicts_head = torch.load(hf_client.hf_hub_download( repo_id=model.model_version_map[model_name], repo_type='model', filename=f'{model_name}.pth', ), map_location="cpu") state_dicts['head.weight'] = state_dicts_head['head.0.weight'] state_dicts['head.bias'] = state_dicts_head['head.0.bias'] model.model.load_state_dict(state_dicts) model.model = model.model.to(model.device) model.model.eval() created_at = hf_client.get_paths_info( repo_id=model.model_version_map[model_name]['repo_id'], repo_type='model', paths=[f'{model_name}.pth'], expand=True )[0].last_commit.date.timestamp() infer_model = model.model transforms = model.transform return model, infer_model, transforms return model, infer_model, transforms, created_at def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bool = False): Loading @@ -70,7 +104,7 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo os.makedirs(export_dir, exist_ok=True) raw_model, model, transforms = load_model(model_name) raw_model, model, transforms, created_at = load_model(model_name) raw_model: PixAITaggerInference image = Image.open(get_testfile('genshin_post.jpg')) dummy_input = transforms(image).unsqueeze(0) Loading Loading @@ -135,12 +169,7 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo 'input_size': dummy_input.shape[2], 'repo_id': src_repo_id, 'model_filename': src_model_filename, 'created_at': hf_client.get_paths_info( repo_id=src_repo_id, repo_type='model', paths=[src_model_filename], expand=True )[0].last_commit.date.timestamp(), 'created_at': created_at, }, f, indent=4, sort_keys=True) logging.info(f'Writing transforms:\n{transforms}') Loading Loading @@ -258,7 +287,7 @@ def sync(repository: str = 'onopix/pixai-tagger-onnx'): else: d_models = {} for model_name in ["tagger_v_2_2_7"]: for model_name in ["tagger_v_2_3_2", "tagger_v_2_2_7"]: with TemporaryDirectory() as upload_dir: logging.info(f'Exporting model {model_name!r} ...') os.makedirs(os.path.join(upload_dir, model_name), exist_ok=True) Loading