Loading zoo/ptagger/model.py +6 −5 Original line number Diff line number Diff line Loading @@ -70,15 +70,16 @@ def load_model(model_name: str = "tagger_v_2_2_7"): )[0].last_commit.date.timestamp() except (KeyError, ValueError): alt_model_name = "tagger_v_2_2_7" logging.info('Cannot directly load it, load from head weights ...') model: PixAITaggerInference = get_model("pixai_tagger", model_version='tagger_v_2_2_7', device='cpu') model: PixAITaggerInference = get_model("pixai_tagger", model_version=alt_model_name, device='cpu') state_dicts = torch.load(hf_client.hf_hub_download( repo_id=model.model_version_map[model_name]['repo_id'], repo_id=model.model_version_map[alt_model_name]['repo_id'], repo_type='model', filename=model.model_version_map[model_name]['ckpt_name'], filename=model.model_version_map[alt_model_name]['ckpt_name'], ), map_location="cpu") state_dicts_head = torch.load(hf_client.hf_hub_download( repo_id=model.model_version_map[model_name], repo_id=model.model_version_map[alt_model_name], repo_type='model', filename=f'{model_name}.pth', ), map_location="cpu") Loading @@ -90,7 +91,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"): logging.info('Head weights loaded.') created_at = hf_client.get_paths_info( repo_id=model.model_version_map[model_name]['repo_id'], repo_id=model.model_version_map[alt_model_name]['repo_id'], repo_type='model', paths=[f'{model_name}.pth'], expand=True Loading Loading
zoo/ptagger/model.py +6 −5 Original line number Diff line number Diff line Loading @@ -70,15 +70,16 @@ def load_model(model_name: str = "tagger_v_2_2_7"): )[0].last_commit.date.timestamp() except (KeyError, ValueError): alt_model_name = "tagger_v_2_2_7" logging.info('Cannot directly load it, load from head weights ...') model: PixAITaggerInference = get_model("pixai_tagger", model_version='tagger_v_2_2_7', device='cpu') model: PixAITaggerInference = get_model("pixai_tagger", model_version=alt_model_name, device='cpu') state_dicts = torch.load(hf_client.hf_hub_download( repo_id=model.model_version_map[model_name]['repo_id'], repo_id=model.model_version_map[alt_model_name]['repo_id'], repo_type='model', filename=model.model_version_map[model_name]['ckpt_name'], filename=model.model_version_map[alt_model_name]['ckpt_name'], ), map_location="cpu") state_dicts_head = torch.load(hf_client.hf_hub_download( repo_id=model.model_version_map[model_name], repo_id=model.model_version_map[alt_model_name], repo_type='model', filename=f'{model_name}.pth', ), map_location="cpu") Loading @@ -90,7 +91,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"): logging.info('Head weights loaded.') created_at = hf_client.get_paths_info( repo_id=model.model_version_map[model_name]['repo_id'], repo_id=model.model_version_map[alt_model_name]['repo_id'], repo_type='model', paths=[f'{model_name}.pth'], expand=True Loading