Commit 8a2b88f6 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): update f1

parent 118610a4
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -29,7 +29,7 @@ def plt_confusion_matrix(ax, y_true, y_pred, title: str = 'Confusion Matrix',
    cm = confusion_matrix(y_true, y_pred, normalize=normalize)
    disp = ConfusionMatrixDisplay(
        confusion_matrix=cm,
        display_labels=['Diff', 'Sim'],
        display_labels=['Similar', 'Diff'],
    )
    disp.plot(ax=ax, cmap=cmap or plt.cm.Blues)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=90)
+2 −2
Original line number Diff line number Diff line
@@ -145,8 +145,8 @@ def create_plots(dist, cmatrix):
    threshold, _ = get_threshold_with_f1(pos, neg)
    plots = {}

    y_true = cmatrix.reshape(-1).type(torch.int).numpy()
    y_pred = (dist <= threshold).reshape(-1).type(torch.int).numpy()
    y_true = (~cmatrix).reshape(-1).type(torch.int).numpy()
    y_pred = (dist >= threshold).reshape(-1).type(torch.int).numpy()
    accuracy = (y_true == y_pred).sum() / y_true.shape[0]
    f1 = f1_score(y_true, y_pred, pos_label=1)
    precision = precision_score(y_true, y_pred, pos_label=1)