Commit 09ddd40f authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add min_samples argument

parent ff54af8b
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -49,6 +49,7 @@ def _open_metric_model(model_name, safe: bool = False):


_VALID_MODEL_NAMES = [
    # 'ccip-caformer-23_randaug_fp32',
    'ccip-caformer-5_fp32',
    'ccip-caformer-4_fp32',
    'ccip-caformer-2_fp32',
@@ -94,7 +95,7 @@ def batch_ccip_similarity(images: Union[np.ndarray, List[_FeatureOrImage]], safe
    return output


def ccip_clustering(images: MultiImagesTyping, threshold: float = 0.6, safe: bool = True,
def ccip_clustering(images: MultiImagesTyping, threshold: float = 0.6, min_samples: int = 2, safe: bool = True,
                    size: int = 384, model_name: str = _DEFAULT_MODEL_NAMES):
    images = load_images(images, mode='RGB')
    features = []
@@ -110,5 +111,5 @@ def ccip_clustering(images: MultiImagesTyping, threshold: float = 0.6, safe: boo
        return differences[int(x), int(y)]

    samples = np.array(range(len(images))).reshape(-1, 1)
    clustering = DBSCAN(eps=1 - threshold, min_samples=2, metric=_metric).fit(samples)
    clustering = DBSCAN(eps=1 - threshold, min_samples=min_samples, metric=_metric).fit(samples)
    return clustering.labels_.tolist()