Commit 25399883 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): fix those issues

parent eb024208
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -30,7 +30,7 @@ _CATEGORY_MAPS = {


@ts_lru_cache()
def _get_camie_model(model_name, is_full: bool):
def _get_camie_model(model_name, is_full: bool = True):
    """
    Load and cache a Camie ONNX model from the Hugging Face Hub.

+57 −3
Original line number Diff line number Diff line
@@ -23,7 +23,7 @@ from thop import profile, clever_format
from imgutils.data import load_image
from imgutils.preprocess import parse_pillow_transforms, create_torchvision_transforms, parse_torchvision_transforms
from imgutils.preprocess.pillow import PillowPadToSize, PillowToTensor, PillowCompose
from .model import create_initial_model, create_refined_model, InitialOnlyWrapper, FullWrapper
from .model import create_initial_model, create_refined_model, InitialOnlyWrapper, FullWrapper, EmbToPredWrapper
from .tags import load_tags
from ..utils import onnx_optimize

@@ -39,7 +39,7 @@ _MODEL_MAP = {


def export_onnx_model(model, dummy_input, onnx_filename: str, is_full: bool = True,
                      opset_version: int = 17, verbose: bool = True, no_optimize: bool = False):
                      opset_version: int = 14, verbose: bool = True, no_optimize: bool = False):
    if not is_full:
        wrapped_model = InitialOnlyWrapper(model)
    else:
@@ -85,6 +85,40 @@ def export_onnx_model(model, dummy_input, onnx_filename: str, is_full: bool = Tr
        onnx.save(model, onnx_filename)


def export_emb_to_pred_onnx_model(model, dummy_input, onnx_filename: str,
                                  opset_version: int = 14, verbose: bool = True, no_optimize: bool = False):
    wrapped_model = EmbToPredWrapper(model)

    with torch.no_grad(), tempfile.TemporaryDirectory() as td:
        onnx_model_file = os.path.join(td, 'model.onnx')
        torch.onnx.export(
            wrapped_model,
            dummy_input,
            onnx_model_file,
            verbose=verbose,
            input_names=["embedding"],
            output_names=(
                ["logits", "output"]
            ),

            opset_version=opset_version,
            dynamic_axes={
                "embedding": {0: "batch"},
                "logits": {0: "batch"},
                "output": {0: "batch"},
            }
        )

        model = onnx.load(onnx_model_file)
        if not no_optimize:
            model = onnx_optimize(model)

        output_model_dir, _ = os.path.split(onnx_filename)
        if output_model_dir:
            os.makedirs(output_model_dir, exist_ok=True)
        onnx.save(model, onnx_filename)


def get_threshold(model_name: str = 'initial'):
    with open(hf_hub_download(
            repo_id='Camais03/camie-tagger',
@@ -98,8 +132,10 @@ def extract(export_dir: str, model_name: str = "initial", no_optimize: bool = Fa
    os.makedirs(export_dir, exist_ok=True)
    tp, model_fn = _MODEL_MAP[model_name]
    tprocess = create_torchvision_transforms(tp)
    model, created_at, (model_repo_id, model_filename) = model_fn()
    model, created_at, (model_repo_id, model_filename), (initial_emb_to_pred, refined_emb_to_pred) = model_fn()
    model = model.eval()
    initial_emb_to_pred.eval()
    refined_emb_to_pred.eval()

    sample_image = load_image(os.path.join('zoo', 'testfile', '6125785.jpg'), mode='RGB', force_background='white')
    dummy_input = tprocess(sample_image).unsqueeze(0)
@@ -157,6 +193,24 @@ def extract(export_dir: str, model_name: str = "initial", no_optimize: bool = Fa
        filename=model_weights_file,
    )

    model_initial_emb_to_pred_onnx_file = os.path.join(export_dir, 'initial_emb_to_pred.onnx')
    logging.info(f'Exporting initial emb to pred model tp {model_initial_emb_to_pred_onnx_file!r} ...')
    export_emb_to_pred_onnx_model(
        model=initial_emb_to_pred,
        dummy_input=dummy_init_embeddings,
        onnx_filename=model_initial_emb_to_pred_onnx_file,
        no_optimize=no_optimize,
    )

    model_refined_emb_to_pred_onnx_file = os.path.join(export_dir, 'refined_emb_to_pred.onnx')
    logging.info(f'Exporting refined emb to pred model to {model_refined_emb_to_pred_onnx_file!r} ...')
    export_emb_to_pred_onnx_model(
        model=refined_emb_to_pred,
        dummy_input=dummy_refined_embeddings,
        onnx_filename=model_refined_emb_to_pred_onnx_file,
        no_optimize=no_optimize,
    )

    model_onnx_file = os.path.join(export_dir, 'model.onnx')
    logging.info(f'Exporting full model to {model_onnx_file!r} ...')
    export_onnx_model(
+1 −1
Original line number Diff line number Diff line
from .initial import create_initial_model
from .refined import create_refined_model
from .wrapper import FullWrapper, InitialOnlyWrapper
from .wrapper import FullWrapper, InitialOnlyWrapper, EmbToPredWrapper
+19 −4
Original line number Diff line number Diff line
@@ -41,6 +41,12 @@ class CamieTaggerInitial(nn.Module):
        # Temperature scaling
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def emb_to_pred(self, features):
        # Initial Tag Predictions
        initial_logits = self.initial_classifier(features)
        initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)
        return initial_preds

    def forward(self, x):
        """Forward pass with only the initial predictions"""
        # Image Feature Extraction
@@ -48,13 +54,21 @@ class CamieTaggerInitial(nn.Module):
        features = self.spatial_pool(features).squeeze(-1).squeeze(-1)

        # Initial Tag Predictions
        initial_logits = self.initial_classifier(features)
        initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)
        initial_preds = self.emb_to_pred(features)

        # For API compatibility with the full model, return the same predictions twice
        return features, initial_preds, features, initial_preds


class InitialEmbToPred(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model: CamieTaggerInitial = model

    def forward(self, features):
        return self.model.emb_to_pred(features)


def create_initial_model():
    repo_id = 'Camais03/camie-tagger'
    filename = 'model_initial.safetensors'
@@ -75,11 +89,12 @@ def create_initial_model():
        filename=filename
    )

    return model, created_at, (repo_id, filename)
    return model, created_at, (repo_id, filename), \
        (InitialEmbToPred(model), InitialEmbToPred(model))


if __name__ == '__main__':
    model, created_at, _ = create_initial_model()
    model, created_at, _, _ = create_initial_model()
    model.eval()  # set to evaluation mode
    print(model)

+34 −7
Original line number Diff line number Diff line
@@ -129,15 +129,24 @@ class CamieTaggerRefined(nn.Module):
        # Temperature parameter for scaling logits (to calibrate confidence)
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def initial_emb_to_pred(self, features):
        initial_logits = self.initial_classifier(features)  # [B, total_tags]
        # Scale by temperature and clamp (to stabilize extreme values, as in original)
        initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)
        return initial_preds

    def refined_emb_to_pred(self, features):
        refined_logits = self.refined_classifier(features)  # [B, total_tags]
        refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0)
        return refined_preds

    def forward(self, images):
        # 1. Feature extraction
        feats = self.backbone.features(images)  # [B, 1280, H/32, W/32] features
        feats = self.spatial_pool(feats).squeeze(-1).squeeze(-1)  # [B, 1280] global feature vector per image

        # 2. Initial tag prediction
        initial_logits = self.initial_classifier(feats)  # [B, total_tags]
        # Scale by temperature and clamp (to stabilize extreme values, as in original)
        initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)
        initial_preds = self.initial_emb_to_pred(feats)

        # 3. Select top-k predicted tags for context (tag_context_size)
        probs = torch.sigmoid(initial_preds)  # convert logits to probabilities
@@ -164,12 +173,29 @@ class CamieTaggerRefined(nn.Module):
        combined_feature = torch.cat([feats, fused_feature], dim=1)  # [B, embedding_dim*2]

        # 8. Refined tag prediction
        refined_logits = self.refined_classifier(combined_feature)  # [B, total_tags]
        refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0)
        refined_preds = self.refined_emb_to_pred(combined_feature)

        return feats, initial_preds, combined_feature, refined_preds


class RefinedInitialEmbToPred(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model: CamieTaggerRefined = model

    def forward(self, features):
        return self.model.initial_emb_to_pred(features)


class RefinedRefinedEmbToPred(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model: CamieTaggerRefined = model

    def forward(self, features):
        return self.model.refined_emb_to_pred(features)


def create_refined_model():
    repo_id = 'Camais03/camie-tagger'
    filename = 'model_refined.safetensors'
@@ -194,11 +220,12 @@ def create_refined_model():
        filename=filename,
    )

    return model, created_at, (repo_id, filename)
    return model, created_at, (repo_id, filename), \
        (RefinedInitialEmbToPred(model), RefinedRefinedEmbToPred(model))


if __name__ == '__main__':
    model, created_at, _ = create_refined_model()
    model, created_at, _, _ = create_refined_model()
    model.eval()  # set to evaluation mode (disable dropout)
    print(model)

Loading