Commit 4191a5ea authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add ccip_merge function

parent 75667b88
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -69,3 +69,11 @@ ccip_clustering



ccip_merge
--------------------------------------------

.. autofunction:: ccip_merge



+38 −0
Original line number Diff line number Diff line
@@ -49,6 +49,8 @@ __all__ = [

    'ccip_default_clustering_params',
    'ccip_clustering',

    'ccip_merge',
]


@@ -517,3 +519,39 @@ def ccip_clustering(images: List[_FeatureOrImage], method: CCIPClusterMethodTypi
        assert False, f'Unknown mode for CCIP clustering - {method!r}.'  # pragma: no cover

    return clustering.labels_.tolist()


def ccip_merge(images: Union[List[_FeatureOrImage], np.ndarray],
               size: int = 384, model: str = _DEFAULT_MODEL_NAMES) -> np.ndarray:
    """
    Merge multiple feature vectors into a single vector.

    :param images: The feature vectors or images to merge.
    :type images: Union[List[_FeatureOrImage], numpy.ndarray]
    :param size: The size of the image. (default: 384)
    :type size: int
    :param model: The name of the model. (default: ``ccip-caformer-24-randaug-pruned``)
    :type model: str
    :return: The merged feature vector.
    :rtype: numpy.ndarray

    Examples::
        >>> from imgutils.metrics import ccip_merge, ccip_batch_differences
        >>>
        >>> images = [f'ccip/{i}.jpg' for i in range(1, 4)]
        >>>
        >>> merged = ccip_merge(images)
        >>> merged.shape
        (768,)
        >>>
        >>> diffs = ccip_batch_differences([merged, *images])[0, 1:]
        >>> diffs
        array([0.07437477, 0.0356068 , 0.04396922], dtype=float32)
        >>> diffs.mean()
        0.05131693
    """
    embs = np.stack([_p_feature(img, size, model) for img in images]).astype(np.float32)
    lengths = np.linalg.norm(embs, axis=-1)
    embs = embs / lengths.reshape(-1, 1)
    ret_embedding = embs.mean(axis=0)
    return ret_embedding / np.linalg.norm(ret_embedding) * lengths.mean()
+60 −2
Original line number Diff line number Diff line
import glob
import json
import os.path
from typing import List, Tuple
from functools import lru_cache
from typing import List, Tuple, Dict, Iterator

import numpy as np
import pytest
from hbutils.testing import disable_output
from huggingface_hub import HfFileSystem, HfApi
from natsort import natsorted
from sklearn.metrics import adjusted_rand_score

from imgutils.metrics import ccip_difference, ccip_default_threshold, ccip_extract_feature, ccip_same, ccip_batch_same, \
    ccip_clustering
    ccip_clustering, ccip_merge, ccip_batch_differences
from test.testings import get_testfile


@@ -99,6 +102,52 @@ def s_threshold(threshold) -> float:
    return threshold + 0.05


MERGE_TAGS = [
    'little_red_riding_hood_(grimm)',
    'maria_cadenzavna_eve',
    'misaka_mikoto',
    'dido_(azur_lane)',
    'hina_(dress)_(blue_archive)',
    'warspite_(kancolle)',
    'kallen_kaslana',
    "kal'tsit_(arknights)",
    'anastasia_(fate)',
    "m16a1_(girls'_frontline)",
]

hf_fs = HfFileSystem(token=os.environ.get('HF_TOKEN'))
hf_client = HfApi(token=os.environ.get('HF_TOKEN'))
SRC_REPO = 'deepghs/character_index'


@lru_cache()
def _get_source_list() -> List[dict]:
    return json.loads(hf_fs.read_text(f'datasets/{SRC_REPO}/characters.json'))


@lru_cache()
def _get_source_dict() -> Dict[str, dict]:
    return {item['tag']: item for item in _get_source_list()}


def list_character_tags() -> Iterator[str]:
    for item in _get_source_list():
        yield item['tag']


def get_detailed_character_info(tag: str) -> dict:
    return _get_source_dict()[tag]


def get_np_feats(tag):
    item = get_detailed_character_info(tag)
    return np.load(hf_client.hf_hub_download(
        repo_id=SRC_REPO,
        repo_type='dataset',
        filename=f'{item["hprefix"]}/{item["short_tag"]}/feat.npy'
    ))


@pytest.mark.unittest
class TestMetricCCIP:
    def test_ccip_difference(self, img_1, img_2, img_3, img_4, img_5, img_6, img_7, s_threshold):
@@ -159,3 +208,12 @@ class TestMetricCCIP:

        with pytest.raises(KeyError):
            _ = ccip_clustering(images_12, min_samples=2, method='what_the_fxxk')

    @pytest.mark.parametrize(['tag'], [
        (tag,) for tag in MERGE_TAGS
    ])
    def test_ccip_merge(self, tag):
        feats = get_np_feats(tag)
        merged_emb = ccip_merge(feats)
        assert ccip_batch_differences([merged_emb, *feats])[0, 1:].mean() <= 0.085
        assert ccip_batch_same([merged_emb, *feats])[0, 1:].all()