Commit 1fb954ee authored by dzy7e's avatar dzy7e
Browse files

optimize dataset loading

parent 946b1481
Loading
Loading
Loading
Loading
+64 −79
Original line number Diff line number Diff line
@@ -10,13 +10,14 @@ from hbutils.random import global_seed
from sklearn import svm
from sklearn.metrics import accuracy_score
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchmetrics import AUROC, AveragePrecision
from torchvision.transforms import Compose
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset, FastCharacterDataset, TEST_TRANSFORM, char_collect_fn
from .loss import NTXentLoss, MLCELoss
from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, FastCharacterDataset, TEST_TRANSFORM, char_collect_fn
from .loss import MLCELoss
from .model import CCIP
from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR

@@ -26,7 +27,6 @@ _CKPT_DIR = os.path.join(_TRAIN_DIR, 'ckpts')

_CKPT_PATTERN = re.compile(r'^ccip-(?P<name>[a-zA-Z\d_\-]+)-(?P<epoch>\d+)\.ckpt$')


def _find_latest_ckpt(name: str) -> Optional[str]:
    if os.path.exists(_CKPT_DIR):
        ckpts = []
@@ -43,7 +43,6 @@ def _find_latest_ckpt(name: str) -> Optional[str]:
    else:
        return None


def _ckpt_epoch(filename: Optional[str]) -> Optional[int]:
    if filename is not None:
        matching = _CKPT_PATTERN.fullmatch(os.path.basename(filename))
@@ -54,7 +53,6 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]:
    else:
        return None


def _sample_analysis(poss, negs, svm_samples: int = 10000):
    poss_cnt, negs_cnt = poss.shape[0], negs.shape[0]
    total = poss_cnt+negs_cnt
@@ -72,18 +70,12 @@ def _sample_analysis(poss, negs, svm_samples: int = 10000):
    model.fit(features.reshape(-1, 1), labels)
    predictions = model.predict(features.reshape(-1, 1))

    coef = model.coef_.reshape(-1)[0].tolist()
    inter = model.intercept_.reshape(-1)[0].tolist()
    threshold = -inter / coef

    return poss.mean().item(), poss.std().item(), negs.mean().item(), negs.std().item(), \
           threshold, accuracy_score(labels, predictions)

    return poss.mean().item(), poss.std().item(), negs.mean().item(), negs.std().item(), accuracy_score(labels, predictions)

def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None,
          train_ratio: float = 0.8, max_epochs: int = 500, group_size: int = 30,
          learning_rate: float = 0.001, weight_decay: float = 1e-2, tau: float = 0.15,
          save_per_epoch: int = 10, eval_epoch: int = 5, num_workers=8,
          save_per_epoch: int = 10, eval_epoch: int = 5, log_iter: int = 500, num_workers=8,
          model_name: str = 'clip/ViT-B/32', seed: Optional[int] = 0):
    if seed is not None:
        # native random, numpy, torch and faker's seeds are includes
@@ -148,10 +140,13 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    model, optimizer, train_dataloader, test_dataloader, scheduler = \
        accelerator.prepare(model, optimizer, train_dataloader, test_dataloader, scheduler)

    metric_auroc = AUROC(task="binary")
    metric_ap = AveragePrecision(task="binary")

    for epoch in range(previous_epoch+1, max_epochs+1):
        running_loss = 0.0
        train_pos_total = 0
        positive_sims, negative_sims = [], []
        pred_list, gt_list = [], []
        model.train()
        for i, (inputs, char_ids) in enumerate(tqdm(train_dataloader)):
            train_dataloader.dataset.reset()
@@ -174,72 +169,62 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
            scheduler.step()

            running_loss += loss.item()*len(char_ids)
            train_pos_total += len(char_ids)

            gt = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu()
            mask = torch.ones_like(outputs)
            mask -= torch.diag_embed(torch.diag(mask))
            outputs = outputs.detach().cpu()
            gt_diag0 = gt.clone()
            gt_diag0.diagonal().copy_(torch.zeros(len(char_ids)))
            # outputs.diagonal().copy_(torch.ones(len(char_ids))*-10000)
            # max_idxs = outputs.argsort(dim=-1)
            # for max_idx, n_pos in zip(max_idxs, gt.sum(dim=1)):
            #     train_pos_total += n_pos
            #     positive_sims.append(outputs[labels])
            #     negative_sims.append(outputs[~labels])
            train_pos_total += gt_diag0.sum()
            positive_sims.append(outputs[gt_diag0])
            negative_sims.append(outputs[~gt])

        epoch_loss = running_loss #/ train_pos_total
        train_psims = torch.cat(positive_sims)
        train_nsims = torch.cat(negative_sims)
        train_pos_mean, train_pos_std, train_neg_mean, train_neg_std, train_threshold, train_acc_svm = \
            _sample_analysis(train_psims, train_nsims)
            gt_same = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu()
            pred_list.append(outputs[mask])
            gt_list.append(gt_same.long()[mask])

            if (i+1)%log_iter == 0:
                pred_t, gt_t = accelerator.gather_for_metrics((torch.cat(pred_list), torch.cat(gt_list)))
                if accelerator.is_local_main_process:
            logging.info(f'Epoch [{epoch}/{max_epochs}], loss: {epoch_loss:.6f}, '
                         f'acc_svm: {train_acc_svm:.6f}, threshold: {train_threshold:.6f}.')
                    auc = metric_auroc(pred_t, gt_t).item()
                    ap = metric_ap(pred_t, gt_t).item()
                    mean_loss = running_loss/train_pos_total
                    logging.info(f'Epoch [{epoch}/{max_epochs}], loss: {mean_loss:.6f}, AUC: {auc:.3e}, AP: {ap:.3e}.')
                    if writer:
                writer.add_scalar('train/loss', epoch_loss, epoch)
                writer.add_scalar('train/pos/mean', train_pos_mean, epoch)
                writer.add_scalar('train/pos/std', train_pos_std, epoch)
                writer.add_scalar('train/neg/mean', train_neg_mean, epoch)
                writer.add_scalar('train/neg/std', train_neg_std, epoch)
                writer.add_scalar('train/threshold', train_threshold, epoch)
                writer.add_scalar('train/acc_svm', train_acc_svm, epoch)
                        writer.add_scalar('train/loss', mean_loss, epoch)
                        writer.add_scalar('train/auc', auc, epoch)
                        writer.add_scalar('train/ap', auc, epoch)

                    pred_list.clear()
                    gt_list.clear()
                    running_loss = 0.
                    train_pos_total = 0

        model.eval()
        if epoch%eval_epoch == 0:
            with torch.no_grad():
                positive_sims, negative_sims = [], []
                pred_list, gt_list = [], []
                for i, (inputs, char_ids) in enumerate(tqdm(test_dataloader)):
                    inputs = inputs.to(accelerator.device)  # BxCxHxW
                    char_ids = char_ids.to(accelerator.device)  # B

                    outputs = model(inputs)  # BxB

                    gt = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu()
                    mask = torch.ones_like(outputs)
                    mask -= torch.diag_embed(torch.diag(mask))
                    outputs = outputs.detach().cpu()
                    gt_diag0 = gt.clone()
                    gt_diag0.diagonal().copy_(torch.zeros(len(char_ids)))
                    train_pos_total += gt_diag0.sum()
                    positive_sims.append(outputs[gt_diag0])
                    negative_sims.append(outputs[~gt])

                test_psims = torch.cat(positive_sims)
                test_nsims = torch.cat(negative_sims)
                test_pos_mean, test_pos_std, test_neg_mean, test_neg_std, test_threshold, test_acc_svm = \
                    _sample_analysis(test_psims, test_nsims)
                    gt_same = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu()
                    pred_list.append(outputs[mask])
                    gt_list.append(gt_same.long()[mask])

                pred_t, gt_t = accelerator.gather_for_metrics((torch.cat(pred_list), torch.cat(gt_list)))
                if accelerator.is_local_main_process:
                    logging.info(f'Epoch {epoch}, '
                                 f'acc_svm: {test_acc_svm:.6f}, threshold: {test_threshold:.6f}')
                    auc = metric_auroc(pred_t, gt_t).item()
                    ap = metric_ap(pred_t, gt_t).item()
                    mean_loss = running_loss/train_pos_total
                    logging.info(f'Epoch [{epoch}/{max_epochs}], loss: {mean_loss:.6f}, AUC: {auc:.3e}, AP: {ap:.3e}.')
                    if writer:
                        writer.add_scalar('test/pos/mean', test_pos_mean, epoch)
                        writer.add_scalar('test/pos/std', test_pos_std, epoch)
                        writer.add_scalar('test/neg/mean', test_neg_mean, epoch)
                        writer.add_scalar('test/neg/std', test_neg_std, epoch)
                        writer.add_scalar('test/threshold', test_threshold, epoch)
                        writer.add_scalar('test/acc_svm', test_acc_svm, epoch)
                        writer.add_scalar('test/loss', mean_loss, epoch)
                        writer.add_scalar('test/auc', auc, epoch)
                        writer.add_scalar('test/ap', auc, epoch)

                    pred_list.clear()
                    gt_list.clear()

        if accelerator.is_local_main_process and epoch%save_per_epoch == 0:
            current_ckpt_file = os.path.join(_CKPT_DIR, f'ccip-{session_name}-{epoch}.ckpt')