Commit 76234f02 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add support for safe metrics

parent 8e34a0fc
Loading
Loading
Loading
Loading
+9 −8
Original line number Diff line number Diff line
@@ -41,14 +41,15 @@ def _open_feat_model(model_name):


@lru_cache()
def _open_metric_model(model_name):
def _open_metric_model(model_name, safe: bool = False):
    return open_onnx_model(hf_hub_download(
        f'deepghs/imgutils-models',
        f'ccip/{model_name}_metrics.onnx',
        f'ccip/{model_name}_{"safe_" if safe else ""}metrics.onnx',
    ))


_VALID_MODEL_NAMES = [
    'ccip-caformer-5_fp32',
    'ccip-caformer-4_fp32',
    'ccip-caformer-2_fp32',
]
@@ -81,19 +82,19 @@ def _preprocess_feats(x, size: int = 384, model_name: str = _DEFAULT_MODEL_NAMES
_FeatureOrImage = Union[ImageTyping, np.ndarray]


def get_ccip_similarity(x: _FeatureOrImage, y: _FeatureOrImage,
def get_ccip_similarity(x: _FeatureOrImage, y: _FeatureOrImage, safe: bool = False,
                        size: int = 384, model_name: str = _DEFAULT_MODEL_NAMES) -> float:
    return batch_ccip_similarity([x, y], size, model_name)[0, 1].item()
    return batch_ccip_similarity([x, y], safe, size, model_name)[0, 1].item()


def batch_ccip_similarity(images: Union[np.ndarray, List[_FeatureOrImage]],
def batch_ccip_similarity(images: Union[np.ndarray, List[_FeatureOrImage]], safe: bool = False,
                          size: int = 384, model_name: str = _DEFAULT_MODEL_NAMES):
    input_ = _preprocess_feats(images, size, model_name).astype(np.float32)
    output, = _open_metric_model(model_name).run(['output'], {'input': input_})
    output, = _open_metric_model(model_name, safe=safe).run(['output'], {'input': input_})
    return output


def ccip_clustering(images: MultiImagesTyping, threshold: float = 0.6,
def ccip_clustering(images: MultiImagesTyping, threshold: float = 0.6, safe: bool = True,
                    size: int = 384, model_name: str = _DEFAULT_MODEL_NAMES):
    images = load_images(images, mode='RGB')
    features = []
@@ -103,7 +104,7 @@ def ccip_clustering(images: MultiImagesTyping, threshold: float = 0.6,
    if not features:
        return []
    feats = np.stack(features)
    differences = 1 - batch_ccip_similarity(feats, size, model_name)
    differences = 1 - batch_ccip_similarity(feats, safe, size, model_name)

    def _metric(x, y):
        return differences[int(x), int(y)]