Commit 8ebf2e43 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): rename function

parent df0b4c2f
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
from .lpips import extract_feature, lpips_difference, batch_lpips_difference
from .lpips import lpips_extract_feature, lpips_difference
+7 −6
Original line number Diff line number Diff line
@@ -27,8 +27,9 @@ def _lpips_feature_model():
    ))


def extract_feature(images: MultiImagesTyping) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    images = load_images(images)
def lpips_extract_feature(image: MultiImagesTyping) \
        -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    images = load_images(image)
    _encoded = np.stack([_image_encode(image) for image in images])
    features = _lpips_feature_model().run(["feat_0", "feat_1", "feat_2", "feat_3", "feat_4"], {'input': _encoded})
    return tuple(features)
@@ -46,7 +47,7 @@ _FEAT1_NAMES = ["feat_x_0", "feat_x_1", "feat_x_2", "feat_x_3", "feat_x_4"]
_FEAT2_NAMES = ["feat_y_0", "feat_y_1", "feat_y_2", "feat_y_3", "feat_y_4"]


def batch_lpips_difference(feats1: Tuple[np.ndarray, ...], feats2: Tuple[np.ndarray, ...]) -> np.ndarray:
def _batch_lpips_difference(feats1: Tuple[np.ndarray, ...], feats2: Tuple[np.ndarray, ...]) -> np.ndarray:
    output, = _lpips_diff_model().run(
        ["output"],
        {
@@ -64,20 +65,20 @@ def _auto_feat(img: AutoFeatTyping):
    if isinstance(img, (tuple, list)):
        return img
    else:
        return extract_feature(load_image(img))
        return lpips_extract_feature(load_image(img))


def lpips_difference(img1: AutoFeatTyping, img2: AutoFeatTyping) -> float:
    img1 = _auto_feat(img1)
    img2 = _auto_feat(img2)
    return batch_lpips_difference(img1, img2).item()
    return _batch_lpips_difference(img1, img2).item()


def lpips_clustering(images: MultiImagesTyping, threshold: float = 0.45) -> List[int]:
    images = load_images(images, mode='RGB')
    n = len(images)

    feat_list = [extract_feature(image) for image in tqdm(images, leave=False, desc='Extract features')]
    feat_list = [lpips_extract_feature(image) for image in tqdm(images, leave=False, desc='Extract features')]
    progress = tqdm(total=n * (n + 1) // 2, leave=False, desc='Metrics')

    @lru_cache(maxsize=n * (n + 1) // 2)