Commit 7704e624 authored by narugo1992
Browse files

dev(narugo): update plots

parent e317a073
Loading
Loading
Loading
Loading
+1 −1
Original line number | Diff line number | Diff line
@@ -25,7 +25,7 @@ def _pos_neg_to_true_score(pos, neg):


def plt_confusion_matrix(ax, y_true, y_pred, title: str = 'Confusion Matrix',
                         normalize: Literal['true', 'pred', None] = None, cmap=None):
                         normalize: Literal['true', 'pred', None] = 'true', cmap=None):
    cm = confusion_matrix(y_true, y_pred, normalize=normalize)
    disp = ConfusionMatrixDisplay(
        confusion_matrix=cm,
+21 −11
Original line number | Diff line number | Diff line
@@ -17,7 +17,7 @@ from huggingface_hub import hf_hub_download, HfApi, CommitOperationAdd
from lighttuner.hpo import hpo, R, uniform, randint
from natsort import natsorted
from sklearn.cluster import OPTICS, DBSCAN
from sklearn.metrics import adjusted_rand_score
from sklearn.metrics import adjusted_rand_score, precision_score, recall_score
from torchvision import transforms
from tqdm.auto import tqdm

@@ -148,12 +148,26 @@ def create_plots(dist, cmatrix):
    y_true = cmatrix.reshape(-1).type(torch.int).numpy()
    y_pred = (dist <= threshold).reshape(-1).type(torch.int).numpy()
    accuracy = (y_true == y_pred).sum() / y_true.shape[0]
    logging.info(f'Threshold: {threshold:.4f}, f1 score: {f1_score:.4f}, accuracy: {accuracy * 100.0:.2f}%')
    precision = precision_score(y_true, y_pred, pos_label=1)
    recall = recall_score(y_true, y_pred, pos_label=1)
    metrics = {
        'threshold': threshold,
        'f1_score': f1_score,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
    }
    logging.info(f'Threshold: {threshold:.4f}, f1 score: {f1_score:.4f}, '
                 f'precision: {precision * 100.0:.2f}%, recall: {recall * 100.0:.2f}%')

    logging.info('Creating confusion matrix ...')
    plots['confusion_matrix'] = plt_export(
    plots['confusion_matrix_true'] = plt_export(
        plt_confusion_matrix, y_true, y_pred,
        normalize='true', title=f'Confusion Matrix (True)'
    )
    plots['confusion_matrix_pred'] = plt_export(
        plt_confusion_matrix, y_true, y_pred,
        title=f'Confusion Matrix\nAccuracy: {accuracy * 100.0:.2f}%'
        normalize='pred', title=f'Confusion Matrix (Predict)'
    )

    logging.info('Creating ROC Curve ...')
@@ -171,7 +185,7 @@ def create_plots(dist, cmatrix):
    logging.info('Creating F1 Curve ...')
    plots['f1'] = plt_export(plt_f1_curve, pos, neg)

    return (threshold, f1_score, accuracy), plots
    return threshold, metrics, plots


def export_model_to_dir(file_in_repo: str, output_dir: str, repository: str = 'deepghs/ccip',
@@ -188,15 +202,11 @@ def export_model_to_dir(file_in_repo: str, output_dir: str, repository: str = 'd

    scale = get_scale_for_model(model)
    dist, cids, cmatrix = _get_dist_matrix(model, scale)
    (threshold, f1_score, accuracy), plots = create_plots(dist, cmatrix)
    threshold, metrics, plots = create_plots(dist, cmatrix)
    metrics_file = os.path.join(output_dir, 'metrics.json')
    logging.info(f'Creating metric file {metrics_file!r} ...')
    with open(metrics_file, 'w') as f:
        json.dump({
            'threshold': threshold,
            'f1_score': f1_score,
            'accuracy': accuracy,
        }, fp=f, indent=4, sort_keys=True, ensure_ascii=False)
        json.dump(metrics, fp=f, indent=4, sort_keys=True, ensure_ascii=False)

    clustering_file = os.path.join(output_dir, 'cluster.json')
    logging.info(f'Creating clustering measurement {clustering_file!r} ...')