Commit 63ae589c authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add clustering test

parent a9b11ed9
Loading
Loading
Loading
Loading
+12 −12
Original line number Diff line number Diff line
@@ -207,18 +207,18 @@ def export_model_to_dir(file_in_repo: str, output_dir: str, repository: str = 'd
    with open(metrics_file, 'w') as f:
        json.dump(metrics, fp=f, indent=4, sort_keys=True, ensure_ascii=False)

    # clustering_file = os.path.join(output_dir, 'cluster.json')
    # logging.info(f'Creating clustering measurement {clustering_file!r} ...')
    # c_results = {}
    # for cname, method, xrange in [
    #     ('dbscan_free', 'dbscan', (2, 5)),
    #     ('dbscan_2', 'dbscan', (2, 2)),
    #     ('optics', 'optics', (2, 5)),
    # ]:
    #     params, score = clustering_metrics(dist, cids, method=method, min_samples_range=xrange)
    #     c_results[cname] = {**params, 'score': score}
    # with open(clustering_file, 'w') as f:
    #     json.dump(c_results, fp=f, indent=4, sort_keys=True, ensure_ascii=False)
    clustering_file = os.path.join(output_dir, 'cluster.json')
    logging.info(f'Creating clustering measurement {clustering_file!r} ...')
    c_results = {}
    for cname, method, xrange in [
        ('dbscan_free', 'dbscan', (2, 5)),
        ('dbscan_2', 'dbscan', (2, 2)),
        ('optics', 'optics', (2, 5)),
    ]:
        params, score = clustering_metrics(dist, cids, method=method, min_samples_range=xrange)
        c_results[cname] = {**params, 'score': score}
    with open(clustering_file, 'w') as f:
        json.dump(c_results, fp=f, indent=4, sort_keys=True, ensure_ascii=False)

    for name, img in plots.items():
        plt_file = os.path.join(output_dir, f'plt_{name}.png')