Commit 118610a4 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): update f1

parent 7704e624
Loading
Loading
Loading
Loading
+5 −5
Original line number Diff line number Diff line
@@ -37,7 +37,7 @@ def plt_confusion_matrix(ax, y_true, y_pred, title: str = 'Confusion Matrix',


@keep_global_state()
def _create_score_curve(ax, name, func, pos, neg, title=None, units: int = 500,
def _create_score_curve(ax, name, func, pos, neg, title=None, units: int = 2000,
                        xrange: Tuple[float, float] = (0.0, 1.0)):
    y_true, y_score = _pos_neg_to_true_score(pos, neg)
    xs, ys = [], []
@@ -64,17 +64,17 @@ def _create_score_curve(ax, name, func, pos, neg, title=None, units: int = 500,
    ax.legend()


def plt_f1_curve(ax, pos, neg, title='F1 Curve', units: int = 500,
def plt_f1_curve(ax, pos, neg, title='F1 Curve', units: int = 2000,
                 xrange: Tuple[float, float] = (0.0, 1.0)):
    _create_score_curve(ax, 'F1', f1_score, pos, neg, title, units, xrange)


def plt_p_curve(ax, pos, neg, title='Precision Curve', units: int = 500,
def plt_p_curve(ax, pos, neg, title='Precision Curve', units: int = 2000,
                xrange: Tuple[float, float] = (0.0, 1.0)):
    _create_score_curve(ax, 'precision', precision_score, pos, neg, title, units, xrange)


def plt_r_curve(ax, pos, neg, title='Recall Curve', units: int = 500,
def plt_r_curve(ax, pos, neg, title='Recall Curve', units: int = 2000,
                xrange: Tuple[float, float] = (0.0, 1.0)):
    _create_score_curve(ax, 'recall', recall_score, pos, neg, title, units, xrange)

@@ -108,7 +108,7 @@ def plt_roc_curve(ax, pos, neg, title: str = 'ROC Curve'):
    ax.legend()


def get_threshold_with_f1(pos, neg, units: int = 500):
def get_threshold_with_f1(pos, neg, units: int = 2000):
    y_true, y_score = _pos_neg_to_true_score(pos, neg)
    xs, ys = [], []
    scores = np.sort(y_score, kind='heapsort')
+17 −16
Original line number Diff line number Diff line
@@ -17,7 +17,7 @@ from huggingface_hub import hf_hub_download, HfApi, CommitOperationAdd
from lighttuner.hpo import hpo, R, uniform, randint
from natsort import natsorted
from sklearn.cluster import OPTICS, DBSCAN
from sklearn.metrics import adjusted_rand_score, precision_score, recall_score
from sklearn.metrics import adjusted_rand_score, precision_score, recall_score, f1_score
from torchvision import transforms
from tqdm.auto import tqdm

@@ -142,22 +142,23 @@ def create_plots(dist, cmatrix):
    # in plot function, score of pos samples should be greater than neg samples
    # so pos and neg should be reversed here!!!
    pos, neg = dist[~cmatrix].numpy(), dist[cmatrix].numpy()
    threshold, f1_score = get_threshold_with_f1(pos, neg)
    threshold, _ = get_threshold_with_f1(pos, neg)
    plots = {}

    y_true = cmatrix.reshape(-1).type(torch.int).numpy()
    y_pred = (dist <= threshold).reshape(-1).type(torch.int).numpy()
    accuracy = (y_true == y_pred).sum() / y_true.shape[0]
    f1 = f1_score(y_true, y_pred, pos_label=1)
    precision = precision_score(y_true, y_pred, pos_label=1)
    recall = recall_score(y_true, y_pred, pos_label=1)
    metrics = {
        'threshold': threshold,
        'f1_score': f1_score,
        'f1_score': f1,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
    }
    logging.info(f'Threshold: {threshold:.4f}, f1 score: {f1_score:.4f}, '
    logging.info(f'Threshold: {threshold:.4f}, f1 score: {f1:.4f}, '
                 f'precision: {precision * 100.0:.2f}%, recall: {recall * 100.0:.2f}%')

    logging.info('Creating confusion matrix ...')
@@ -208,18 +209,18 @@ def export_model_to_dir(file_in_repo: str, output_dir: str, repository: str = 'd
    with open(metrics_file, 'w') as f:
        json.dump(metrics, fp=f, indent=4, sort_keys=True, ensure_ascii=False)

    clustering_file = os.path.join(output_dir, 'cluster.json')
    logging.info(f'Creating clustering measurement {clustering_file!r} ...')
    c_results = {}
    for cname, method, xrange in [
        ('dbscan_free', 'dbscan', (2, 5)),
        ('dbscan_2', 'dbscan', (2, 2)),
        ('optics', 'optics', (2, 5)),
    ]:
        params, score = clustering_metrics(dist, cids, method=method, min_samples_range=xrange)
        c_results[cname] = {**params, 'score': score}
    with open(clustering_file, 'w') as f:
        json.dump(c_results, fp=f, indent=4, sort_keys=True, ensure_ascii=False)
    # clustering_file = os.path.join(output_dir, 'cluster.json')
    # logging.info(f'Creating clustering measurement {clustering_file!r} ...')
    # c_results = {}
    # for cname, method, xrange in [
    #     ('dbscan_free', 'dbscan', (2, 5)),
    #     ('dbscan_2', 'dbscan', (2, 2)),
    #     ('optics', 'optics', (2, 5)),
    # ]:
    #     params, score = clustering_metrics(dist, cids, method=method, min_samples_range=xrange)
    #     c_results[cname] = {**params, 'score': score}
    # with open(clustering_file, 'w') as f:
    #     json.dump(c_results, fp=f, indent=4, sort_keys=True, ensure_ascii=False)

    for name, img in plots.items():
        plt_file = os.path.join(output_dir, f'plt_{name}.png')