Commit 4db5f451 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add new export code, ci skip

parent 0a5707ee
Loading
Loading
Loading
Loading
+6 −5
Original line number Diff line number Diff line
@@ -70,15 +70,16 @@ def load_model(model_name: str = "tagger_v_2_2_7"):
        )[0].last_commit.date.timestamp()

    except (KeyError, ValueError):
        alt_model_name = "tagger_v_2_2_7"
        logging.info('Cannot directly load it, load from head weights ...')
        model: PixAITaggerInference = get_model("pixai_tagger", model_version='tagger_v_2_2_7', device='cpu')
        model: PixAITaggerInference = get_model("pixai_tagger", model_version=alt_model_name, device='cpu')
        state_dicts = torch.load(hf_client.hf_hub_download(
            repo_id=model.model_version_map[model_name]['repo_id'],
            repo_id=model.model_version_map[alt_model_name]['repo_id'],
            repo_type='model',
            filename=model.model_version_map[model_name]['ckpt_name'],
            filename=model.model_version_map[alt_model_name]['ckpt_name'],
        ), map_location="cpu")
        state_dicts_head = torch.load(hf_client.hf_hub_download(
            repo_id=model.model_version_map[model_name],
            repo_id=model.model_version_map[alt_model_name],
            repo_type='model',
            filename=f'{model_name}.pth',
        ), map_location="cpu")
@@ -90,7 +91,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"):
        logging.info('Head weights loaded.')

        created_at = hf_client.get_paths_info(
            repo_id=model.model_version_map[model_name]['repo_id'],
            repo_id=model.model_version_map[alt_model_name]['repo_id'],
            repo_type='model',
            paths=[f'{model_name}.pth'],
            expand=True