Commit d9563dee authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add new export code

parent 22085242
Loading
Loading
Loading
Loading
+39 −10
Original line number Diff line number Diff line
@@ -58,10 +58,44 @@ class ModuleWrapper(nn.Module):


def load_model(model_name: str = "tagger_v_2_2_7"):
    hf_client = get_hf_client()
    try:
        model: PixAITaggerInference = get_model("pixai_tagger", model_version=model_name, device='cpu')
        created_at = hf_client.get_paths_info(
            repo_id=model.model_version_map[model_name]['repo_id'],
            repo_type='model',
            paths=[model.model_version_map[model_name]['ckpt_name']],
            expand=True
        )[0].last_commit.date.timestamp()

    except KeyError:
        model: PixAITaggerInference = get_model("pixai_tagger", model_version='tagger_v_2_2_7', device='cpu')
        state_dicts = torch.load(hf_client.hf_hub_download(
            repo_id=model.model_version_map[model_name]['repo_id'],
            repo_type='model',
            filename=model.model_version_map[model_name]['ckpt_name'],
        ), map_location="cpu")
        state_dicts_head = torch.load(hf_client.hf_hub_download(
            repo_id=model.model_version_map[model_name],
            repo_type='model',
            filename=f'{model_name}.pth',
        ), map_location="cpu")
        state_dicts['head.weight'] = state_dicts_head['head.0.weight']
        state_dicts['head.bias'] = state_dicts_head['head.0.bias']
        model.model.load_state_dict(state_dicts)
        model.model = model.model.to(model.device)
        model.model.eval()

        created_at = hf_client.get_paths_info(
            repo_id=model.model_version_map[model_name]['repo_id'],
            repo_type='model',
            paths=[f'{model_name}.pth'],
            expand=True
        )[0].last_commit.date.timestamp()

    infer_model = model.model
    transforms = model.transform
    return model, infer_model, transforms
    return model, infer_model, transforms, created_at


def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bool = False):
@@ -70,7 +104,7 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo

    os.makedirs(export_dir, exist_ok=True)

    raw_model, model, transforms = load_model(model_name)
    raw_model, model, transforms, created_at = load_model(model_name)
    raw_model: PixAITaggerInference
    image = Image.open(get_testfile('genshin_post.jpg'))
    dummy_input = transforms(image).unsqueeze(0)
@@ -135,12 +169,7 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo
            'input_size': dummy_input.shape[2],
            'repo_id': src_repo_id,
            'model_filename': src_model_filename,
            'created_at': hf_client.get_paths_info(
                repo_id=src_repo_id,
                repo_type='model',
                paths=[src_model_filename],
                expand=True
            )[0].last_commit.date.timestamp(),
            'created_at': created_at,
        }, f, indent=4, sort_keys=True)

    logging.info(f'Writing transforms:\n{transforms}')
@@ -258,7 +287,7 @@ def sync(repository: str = 'onopix/pixai-tagger-onnx'):
    else:
        d_models = {}

    for model_name in ["tagger_v_2_2_7"]:
    for model_name in ["tagger_v_2_3_2", "tagger_v_2_2_7"]:
        with TemporaryDirectory() as upload_dir:
            logging.info(f'Exporting model {model_name!r} ...')
            os.makedirs(os.path.join(upload_dir, model_name), exist_ok=True)