Commit 24ca0f2d authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): try fix this

parent 63ae589c
Loading
Loading
Loading
Loading
+7 −4
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ from typing import Tuple
import numpy as np
import torch
from PIL import Image
from hbutils.random import keep_global_state
from hbutils.random import keep_global_state, global_seed
from hbutils.system import TemporaryDirectory
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
@@ -37,8 +37,9 @@ def plt_confusion_matrix(ax, y_true, y_pred, title: str = 'Confusion Matrix',


@keep_global_state()
def _create_score_curve(ax, name, func, pos, neg, title=None, units: int = 500,
                        xrange: Tuple[float, float] = (0.0, 1.0)):
def _create_score_curve(ax, name, func, pos, neg, title=None, units: int = 2000,
                        xrange: Tuple[float, float] = (0.0, 1.0), seed=0):
    global_seed(seed)
    y_true, y_score = _pos_neg_to_true_score(pos, neg)
    y_true = 1 - y_true
    xs, ys = [], []
@@ -109,7 +110,9 @@ def plt_roc_curve(ax, pos, neg, title: str = 'ROC Curve'):
    ax.legend()


def get_threshold_with_f1(pos, neg, units: int = 500):
@keep_global_state()
def get_threshold_with_f1(pos, neg, units: int = 2000, seed: int = 0):
    global_seed(seed)
    y_true, y_score = _pos_neg_to_true_score(pos, neg)
    y_true = 1 - y_true
    xs, ys = [], []
+12 −12
Original line number Diff line number Diff line
@@ -207,18 +207,18 @@ def export_model_to_dir(file_in_repo: str, output_dir: str, repository: str = 'd
    with open(metrics_file, 'w') as f:
        json.dump(metrics, fp=f, indent=4, sort_keys=True, ensure_ascii=False)

    clustering_file = os.path.join(output_dir, 'cluster.json')
    logging.info(f'Creating clustering measurement {clustering_file!r} ...')
    c_results = {}
    for cname, method, xrange in [
        ('dbscan_free', 'dbscan', (2, 5)),
        ('dbscan_2', 'dbscan', (2, 2)),
        ('optics', 'optics', (2, 5)),
    ]:
        params, score = clustering_metrics(dist, cids, method=method, min_samples_range=xrange)
        c_results[cname] = {**params, 'score': score}
    with open(clustering_file, 'w') as f:
        json.dump(c_results, fp=f, indent=4, sort_keys=True, ensure_ascii=False)
    # clustering_file = os.path.join(output_dir, 'cluster.json')
    # logging.info(f'Creating clustering measurement {clustering_file!r} ...')
    # c_results = {}
    # for cname, method, xrange in [
    #     ('dbscan_free', 'dbscan', (2, 5)),
    #     ('dbscan_2', 'dbscan', (2, 2)),
    #     ('optics', 'optics', (2, 5)),
    # ]:
    #     params, score = clustering_metrics(dist, cids, method=method, min_samples_range=xrange)
    #     c_results[cname] = {**params, 'score': score}
    # with open(clustering_file, 'w') as f:
    #     json.dump(c_results, fp=f, indent=4, sort_keys=True, ensure_ascii=False)

    for name, img in plots.items():
        plt_file = os.path.join(output_dir, f'plt_{name}.png')