Commit a9b11ed9 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): update calculation

parent 8a2b88f6
Loading
Loading
Loading
Loading
+11 −9
Original line number Diff line number Diff line
@@ -37,15 +37,16 @@ def plt_confusion_matrix(ax, y_true, y_pred, title: str = 'Confusion Matrix',


@keep_global_state()
def _create_score_curve(ax, name, func, pos, neg, title=None, units: int = 2000,
def _create_score_curve(ax, name, func, pos, neg, title=None, units: int = 500,
                        xrange: Tuple[float, float] = (0.0, 1.0)):
    y_true, y_score = _pos_neg_to_true_score(pos, neg)
    y_true = 1 - y_true
    xs, ys = [], []
    scores = np.sort(y_score, kind='heapsort')
    if len(scores) > units:
        scores = np.random.choice(scores, units)
    for score in np.sort(scores, kind='heapsort'):
        _y_pred = y_score >= score
        _y_pred = y_score <= score
        precision = func(y_true, _y_pred, zero_division=1)
        xs.append(score)
        ys.append(precision)
@@ -64,24 +65,24 @@ def _create_score_curve(ax, name, func, pos, neg, title=None, units: int = 2000,
    ax.legend()


def plt_f1_curve(ax, pos, neg, title='F1 Curve', units: int = 2000,
def plt_f1_curve(ax, pos, neg, title='F1 Curve', units: int = 500,
                 xrange: Tuple[float, float] = (0.0, 1.0)):
    _create_score_curve(ax, 'F1', f1_score, pos, neg, title, units, xrange)


def plt_p_curve(ax, pos, neg, title='Precision Curve', units: int = 2000,
def plt_p_curve(ax, pos, neg, title='Precision Curve', units: int = 500,
                xrange: Tuple[float, float] = (0.0, 1.0)):
    _create_score_curve(ax, 'precision', precision_score, pos, neg, title, units, xrange)


def plt_r_curve(ax, pos, neg, title='Recall Curve', units: int = 2000,
def plt_r_curve(ax, pos, neg, title='Recall Curve', units: int = 500,
                xrange: Tuple[float, float] = (0.0, 1.0)):
    _create_score_curve(ax, 'recall', recall_score, pos, neg, title, units, xrange)


def plt_pr_curve(ax, pos, neg, title='PR Curve'):
    y_true, y_score = _pos_neg_to_true_score(pos, neg)
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    precision, recall, _ = precision_recall_curve(1 - y_true, 1.0 - y_score)
    disp = PrecisionRecallDisplay(precision=precision, recall=recall)
    _map = -np.trapz(precision, recall)
    disp.plot(ax=ax, name=f'mAP {_map:.3f}')
@@ -95,7 +96,7 @@ def plt_pr_curve(ax, pos, neg, title='PR Curve'):

def plt_roc_curve(ax, pos, neg, title: str = 'ROC Curve'):
    y_true, y_score = _pos_neg_to_true_score(pos, neg)
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    fpr, tpr, thresholds = roc_curve(1 - y_true, 1.0 - y_score)
    auc_value = auc(fpr, tpr)

    display = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=auc_value)
@@ -108,14 +109,15 @@ def plt_roc_curve(ax, pos, neg, title: str = 'ROC Curve'):
    ax.legend()


def get_threshold_with_f1(pos, neg, units: int = 2000):
def get_threshold_with_f1(pos, neg, units: int = 500):
    y_true, y_score = _pos_neg_to_true_score(pos, neg)
    y_true = 1 - y_true
    xs, ys = [], []
    scores = np.sort(y_score, kind='heapsort')
    if len(scores) > units:
        scores = np.random.choice(scores, units)
    for score in np.sort(scores, kind='heapsort'):
        _y_pred = y_score >= score
        _y_pred = y_score <= score
        precision = f1_score(y_true, _y_pred, zero_division=1)
        xs.append(score)
        ys.append(precision)
+2 −4
Original line number Diff line number Diff line
@@ -145,16 +145,14 @@ def create_plots(dist, cmatrix):
    threshold, _ = get_threshold_with_f1(pos, neg)
    plots = {}

    y_true = (~cmatrix).reshape(-1).type(torch.int).numpy()
    y_pred = (dist >= threshold).reshape(-1).type(torch.int).numpy()
    accuracy = (y_true == y_pred).sum() / y_true.shape[0]
    y_true = cmatrix.reshape(-1).type(torch.int).numpy()
    y_pred = (dist <= threshold).reshape(-1).type(torch.int).numpy()
    f1 = f1_score(y_true, y_pred, pos_label=1)
    precision = precision_score(y_true, y_pred, pos_label=1)
    recall = recall_score(y_true, y_pred, pos_label=1)
    metrics = {
        'threshold': threshold,
        'f1_score': f1,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
    }