Commit a4453c8f authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): bug fix

parent a43cbf00
Loading
Loading
Loading
Loading
+9 −9
Original line number Diff line number Diff line
@@ -37,8 +37,8 @@ def plt_confusion_matrix(ax, y_true, y_pred, title: str = 'Confusion Matrix',


@keep_global_state()
def _create_score_curve(ax, name, func, pos, neg, title=None, units: int = 1000,
                        samples: int = 1000, xrange: Tuple[float, float] = (0.0, 1.0), seed=0):
def _create_score_curve(ax, name, func, pos, neg, title=None, units: int = 500,
                        samples: int = 500, xrange: Tuple[float, float] = (0.0, 1.0), seed=0):
    global_seed(seed)
    y_true, y_score = _pos_neg_to_true_score(pos, neg)
    y_true = 1 - y_true
@@ -70,18 +70,18 @@ def _create_score_curve(ax, name, func, pos, neg, title=None, units: int = 1000,


def plt_f1_curve(ax, pos, neg, title='F1 Curve', units: int = 500,
                 xrange: Tuple[float, float] = (0.0, 1.0)):
    _create_score_curve(ax, 'F1', f1_score, pos, neg, title, units, xrange)
                 samples: int = 500, xrange: Tuple[float, float] = (0.0, 1.0)):
    _create_score_curve(ax, 'F1', f1_score, pos, neg, title, units, samples, xrange)


def plt_p_curve(ax, pos, neg, title='Precision Curve', units: int = 500,
                xrange: Tuple[float, float] = (0.0, 1.0)):
    _create_score_curve(ax, 'precision', precision_score, pos, neg, title, units, xrange)
                samples: int = 500, xrange: Tuple[float, float] = (0.0, 1.0)):
    _create_score_curve(ax, 'precision', precision_score, pos, neg, title, units, samples, xrange)


def plt_r_curve(ax, pos, neg, title='Recall Curve', units: int = 500,
                xrange: Tuple[float, float] = (0.0, 1.0)):
    _create_score_curve(ax, 'recall', recall_score, pos, neg, title, units, xrange)
                samples: int = 500, xrange: Tuple[float, float] = (0.0, 1.0)):
    _create_score_curve(ax, 'recall', recall_score, pos, neg, title, units, samples, xrange)


def plt_pr_curve(ax, pos, neg, title='PR Curve'):
@@ -114,7 +114,7 @@ def plt_roc_curve(ax, pos, neg, title: str = 'ROC Curve'):


@keep_global_state()
def get_threshold_with_f1(pos, neg, units: int = 1000, samples: int = 1000, seed: int = 0):
def get_threshold_with_f1(pos, neg, units: int = 500, samples: int = 500, seed: int = 0):
    global_seed(seed)
    y_true, y_score = _pos_neg_to_true_score(pos, neg)
    y_true = 1 - y_true