Commit a43cbf00 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): try fix this

parent 24ca0f2d
Loading
Loading
Loading
Loading
+9 −3
Original line number Diff line number Diff line
@@ -37,8 +37,8 @@ def plt_confusion_matrix(ax, y_true, y_pred, title: str = 'Confusion Matrix',


@keep_global_state()
def _create_score_curve(ax, name, func, pos, neg, title=None, units: int = 2000,
                        xrange: Tuple[float, float] = (0.0, 1.0), seed=0):
def _create_score_curve(ax, name, func, pos, neg, title=None, units: int = 1000,
                        samples: int = 1000, xrange: Tuple[float, float] = (0.0, 1.0), seed=0):
    global_seed(seed)
    y_true, y_score = _pos_neg_to_true_score(pos, neg)
    y_true = 1 - y_true
@@ -46,6 +46,9 @@ def _create_score_curve(ax, name, func, pos, neg, title=None, units: int = 2000,
    scores = np.sort(y_score, kind='heapsort')
    if len(scores) > units:
        scores = np.random.choice(scores, units)
    sps = np.linspace(y_score.min(), y_score.max(), samples)
    scores = np.concatenate([sps, scores])

    for score in np.sort(scores, kind='heapsort'):
        _y_pred = y_score <= score
        precision = func(y_true, _y_pred, zero_division=1)
@@ -111,7 +114,7 @@ def plt_roc_curve(ax, pos, neg, title: str = 'ROC Curve'):


@keep_global_state()
def get_threshold_with_f1(pos, neg, units: int = 2000, seed: int = 0):
def get_threshold_with_f1(pos, neg, units: int = 1000, samples: int = 1000, seed: int = 0):
    global_seed(seed)
    y_true, y_score = _pos_neg_to_true_score(pos, neg)
    y_true = 1 - y_true
@@ -119,6 +122,9 @@ def get_threshold_with_f1(pos, neg, units: int = 2000, seed: int = 0):
    scores = np.sort(y_score, kind='heapsort')
    if len(scores) > units:
        scores = np.random.choice(scores, units)
    sps = np.linspace(y_score.min(), y_score.max(), samples)
    scores = np.concatenate([sps, scores])

    for score in np.sort(scores, kind='heapsort'):
        _y_pred = y_score <= score
        precision = f1_score(y_true, _y_pred, zero_division=1)