Commit f1cdcf92 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): export epoch5

parent 01be848b
Loading
Loading
Loading
Loading
+23 −0
Original line number Diff line number Diff line
@@ -4,6 +4,8 @@ from typing import Union, List
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from sklearn.cluster import DBSCAN
from tqdm.auto import tqdm

from ..data import MultiImagesTyping, load_images, ImageTyping
from ..utils import open_onnx_model
@@ -12,6 +14,7 @@ __all__ = [
    'get_ccip_features',
    'get_ccip_similarity',
    'batch_ccip_similarity',
    'ccip_clustering',
]


@@ -88,3 +91,23 @@ def batch_ccip_similarity(images: Union[np.ndarray, List[_FeatureOrImage]],
    input_ = _preprocess_feats(images, size, model_name).astype(np.float32)
    output, = _open_metric_model(model_name).run(['output'], {'input': input_})
    return output


def ccip_clustering(images: MultiImagesTyping, threshold: float = 0.6,
                    size: int = 384, model_name: str = _DEFAULT_MODEL_NAMES):
    images = load_images(images, mode='RGB')
    features = []
    for image in tqdm(images, desc='Feature Extract'):
        features.append(get_ccip_features([image], size, model_name)[0])

    if not features:
        return []
    feats = np.stack(features)
    differences = 1 - batch_ccip_similarity(feats, size, model_name)

    def _metric(x, y):
        return differences[int(x), int(y)]

    samples = np.array(range(len(images))).reshape(-1, 1)
    clustering = DBSCAN(eps=1 - threshold, min_samples=2, metric=_metric).fit(samples)
    return clustering.labels_.tolist()
+3 −2
Original line number Diff line number Diff line
@@ -148,8 +148,9 @@ def onnx_check(model: str, check_item: Optional[str] = None, verbose: bool = Fal


MODELS = [
    ('caformer', 'ccip-caformer-2_fp32.ckpt'),
    ('caformer', 'ccip-caformer-4_fp32.ckpt'),
    # ('caformer', 'ccip-caformer-2_fp32.ckpt'),
    # ('caformer', 'ccip-caformer-4_fp32.ckpt'),
    ('caformer', 'ccip-caformer-5_fp32.ckpt'),
]