Commit ad5cbb05 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add new export code

parent d9563dee
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -60,6 +60,7 @@ class ModuleWrapper(nn.Module):
def load_model(model_name: str = "tagger_v_2_2_7"):
    hf_client = get_hf_client()
    try:
        logging.info(f'Try loading model {model_name!r} ...')
        model: PixAITaggerInference = get_model("pixai_tagger", model_version=model_name, device='cpu')
        created_at = hf_client.get_paths_info(
            repo_id=model.model_version_map[model_name]['repo_id'],
@@ -69,6 +70,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"):
        )[0].last_commit.date.timestamp()

    except KeyError:
        logging.info('Cannot directly load it, load from head weights ...')
        model: PixAITaggerInference = get_model("pixai_tagger", model_version='tagger_v_2_2_7', device='cpu')
        state_dicts = torch.load(hf_client.hf_hub_download(
            repo_id=model.model_version_map[model_name]['repo_id'],
@@ -85,6 +87,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"):
        model.model.load_state_dict(state_dicts)
        model.model = model.model.to(model.device)
        model.model.eval()
        logging.info('Head weights loaded.')

        created_at = hf_client.get_paths_info(
            repo_id=model.model_version_map[model_name]['repo_id'],
@@ -294,6 +297,7 @@ def sync(repository: str = 'onopix/pixai-tagger-onnx'):
            try:
                extract(
                    export_dir=os.path.join(upload_dir, model_name),
                    model_name=model_name,
                    no_optimize=False,
                )
            except Exception: