Commit 52a01093 authored by narugo1992
Browse files

dev(narugo): save this

parent 33eef61f
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -19,3 +19,5 @@ regex
git+https://github.com/openai/CLIP.git
ultralytics
controlnet_aux
lighttuner
natsort
 No newline at end of file
+77 −2
Original line number Diff line number Diff line
@@ -4,18 +4,28 @@ import json
import os
import shutil
from functools import partial
from typing import Tuple

import click
import numpy as np
import torch
from PIL import Image
from ditk import logging
from hbutils.system import TemporaryDirectory
from hbutils.testing import disable_output
from huggingface_hub import hf_hub_download, HfApi, CommitOperationAdd
from lighttuner.hpo import hpo, R, uniform, randint
from natsort import natsorted
from sklearn.cluster import OPTICS, DBSCAN
from sklearn.metrics import adjusted_rand_score
from torchvision import transforms
from tqdm.auto import tqdm

try:
    from typing import Literal
except (ImportError, ModuleNotFoundError):
    from typing_extensions import Literal

from test.testings import get_testfile
from zoo.ccip.demo import _get_model_from_ckpt
from .dataset import TEST_TRANSFORM
@@ -73,7 +83,59 @@ def _get_dist_matrix(model: CCIP, scale: float, batch: int = 32):
    cids = torch.tensor(cids)
    cmatrix = cids == cids.reshape(-1, 1)

    return dist, cmatrix
    return dist, cids, cmatrix


def clustering_metrics(dist, cids, method: Literal['dbscan', 'optics'] = 'dbscan',
                       init_steps: int = 40, max_steps: int = 200,
                       min_samples_range: Tuple[int, int] = (2, 5)):
    """
    Search for clustering hyper-parameters whose result best matches ``cids``.

    Every sample is represented by its integer index, and the clustering
    algorithm is given a callable metric that looks the distance of two
    indices up in ``dist``. Each trial is scored with the adjusted Rand index
    between ``cids`` and the predicted labels, and the search maximizes it.

    :param dist: Pairwise distance container indexed as ``dist[i, j]`` whose
        elements expose ``.item()`` (presumably a square torch tensor -- confirm
        against ``_get_dist_matrix``).
    :param cids: Reference labels, one per sample; must expose ``.shape[0]``.
    :param method: Clustering algorithm, ``'dbscan'`` or ``'optics'``.
    :param init_steps: Value forwarded to the HPO builder's ``init_steps``.
    :param max_steps: Value forwarded to the HPO builder's ``max_steps``.
    :param min_samples_range: Two integers unpacked into ``randint`` to build
        the ``min_samples`` search space (bound inclusivity is defined by
        lighttuner -- TODO confirm, notably for equal bounds such as ``(2, 2)``).
    :return: Tuple ``(params, score)`` as produced by the HPO run.
    :raises AssertionError: If ``method`` is not one of the supported names.
    """
    assert method in {'dbscan', 'optics'}, f'Method {method!r} not found.'

    def _trans_id(x):
        # Re-number labels in order of first appearance. The label -1 marks
        # noise in scikit-learn, so every noise sample receives a fresh id of
        # its own and is therefore scored as a singleton cluster.
        max_id = 0
        _maps = {}
        retval = []
        for item in x:
            if item == -1:
                retval.append(max_id)
                max_id += 1
            else:
                if item not in _maps:
                    _maps[item] = max_id
                    max_id += 1
                retval.append(_maps[item])
        return retval

    def _as_index(value):
        # The metric receives rows of ``samples`` (1-element arrays). Calling
        # ``int()`` directly on an array with ndim > 0 is deprecated since
        # NumPy 1.25, so pull the single element out explicitly first.
        return int(np.asarray(value).item())

    def _metric(x, y):
        return dist[_as_index(x), _as_index(y)].item()

    # Trial-independent, so built once instead of once per HPO step.
    samples = np.arange(cids.shape[0]).reshape(-1, 1)

    @hpo
    def opt_func(v):  # this function is still usable after decorating
        min_samples, eps = v['min_samples'], v['eps']

        if method == 'dbscan':
            clustering = DBSCAN(eps=eps, min_samples=min_samples, metric=_metric).fit(samples)
        elif method == 'optics':
            clustering = OPTICS(max_eps=eps, min_samples=min_samples, metric=_metric).fit(samples)
        else:
            assert False, 'Should not reach here!'

        ret_ids = clustering.labels_.tolist()
        logging.info(f'Cluster result: {ret_ids!r}')
        return adjusted_rand_score(cids, _trans_id(ret_ids))

    logging.info('Waiting for HPO ...')
    search = (
        opt_func.bayes()
        .init_steps(init_steps)
        .max_steps(max_steps)
        .maximize(R).max_workers(1).rank(10)
        .spaces({
            'min_samples': randint(*min_samples_range),
            'eps': uniform(0.0, 0.5),
        })
    )
    params, score, _ = search.run()

    return params, score


def create_plots(dist, cmatrix):
@@ -125,7 +187,7 @@ def export_model_to_dir(file_in_repo: str, output_dir: str, repository: str = 'd
    shutil.copyfile(ckpt_file, model_ckpt_file)

    scale = get_scale_for_model(model)
    dist, cmatrix = _get_dist_matrix(model, scale)
    dist, cids, cmatrix = _get_dist_matrix(model, scale)
    (threshold, f1_score, accuracy), plots = create_plots(dist, cmatrix)
    metrics_file = os.path.join(output_dir, 'metrics.json')
    logging.info(f'Creating metric file {metrics_file!r} ...')
@@ -136,6 +198,19 @@ def export_model_to_dir(file_in_repo: str, output_dir: str, repository: str = 'd
            'accuracy': accuracy,
        }, fp=f, indent=4, sort_keys=True, ensure_ascii=False)

    clustering_file = os.path.join(output_dir, 'cluster.json')
    logging.info(f'Creating clustering measurement {clustering_file!r} ...')
    c_results = {}
    for cname, method, xrange in [
        ('dbscan_free', 'dbscan', (2, 5)),
        ('dbscan_2', 'dbscan', (2, 2)),
        ('optics', 'optics', (2, 5)),
    ]:
        params, score = clustering_metrics(dist, cids, method=method, min_samples_range=xrange)
        c_results[cname] = {**params, 'score': score}
    with open(clustering_file, 'w') as f:
        json.dump(c_results, fp=f, indent=4, sort_keys=True, ensure_ascii=False)

    for name, img in plots.items():
        plt_file = os.path.join(output_dir, f'plt_{name}.png')
        logging.info(f'Saving plotting file {plt_file!r} ...')