Loading zoo/ptagger/model.py +4 −0 Original line number Diff line number Diff line Loading @@ -60,6 +60,7 @@ class ModuleWrapper(nn.Module): def load_model(model_name: str = "tagger_v_2_2_7"): hf_client = get_hf_client() try: logging.info(f'Try loading model {model_name!r} ...') model: PixAITaggerInference = get_model("pixai_tagger", model_version=model_name, device='cpu') created_at = hf_client.get_paths_info( repo_id=model.model_version_map[model_name]['repo_id'], Loading @@ -69,6 +70,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"): )[0].last_commit.date.timestamp() except KeyError: logging.info('Cannot directly load it, load from head weights ...') model: PixAITaggerInference = get_model("pixai_tagger", model_version='tagger_v_2_2_7', device='cpu') state_dicts = torch.load(hf_client.hf_hub_download( repo_id=model.model_version_map[model_name]['repo_id'], Loading @@ -85,6 +87,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"): model.model.load_state_dict(state_dicts) model.model = model.model.to(model.device) model.model.eval() logging.info('Head weights loaded.') created_at = hf_client.get_paths_info( repo_id=model.model_version_map[model_name]['repo_id'], Loading Loading @@ -294,6 +297,7 @@ def sync(repository: str = 'onopix/pixai-tagger-onnx'): try: extract( export_dir=os.path.join(upload_dir, model_name), model_name=model_name, no_optimize=False, ) except Exception: Loading Loading
zoo/ptagger/model.py +4 −0 Original line number Diff line number Diff line Loading @@ -60,6 +60,7 @@ class ModuleWrapper(nn.Module): def load_model(model_name: str = "tagger_v_2_2_7"): hf_client = get_hf_client() try: logging.info(f'Try loading model {model_name!r} ...') model: PixAITaggerInference = get_model("pixai_tagger", model_version=model_name, device='cpu') created_at = hf_client.get_paths_info( repo_id=model.model_version_map[model_name]['repo_id'], Loading @@ -69,6 +70,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"): )[0].last_commit.date.timestamp() except KeyError: logging.info('Cannot directly load it, load from head weights ...') model: PixAITaggerInference = get_model("pixai_tagger", model_version='tagger_v_2_2_7', device='cpu') state_dicts = torch.load(hf_client.hf_hub_download( repo_id=model.model_version_map[model_name]['repo_id'], Loading @@ -85,6 +87,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"): model.model.load_state_dict(state_dicts) model.model = model.model.to(model.device) model.model.eval() logging.info('Head weights loaded.') created_at = hf_client.get_paths_info( repo_id=model.model_version_map[model_name]['repo_id'], Loading Loading @@ -294,6 +297,7 @@ def sync(repository: str = 'onopix/pixai-tagger-onnx'): try: extract( export_dir=os.path.join(upload_dir, model_name), model_name=model_name, no_optimize=False, ) except Exception: Loading