Loading zoo/ccip/train_.py +64 −79 Original line number Diff line number Diff line Loading @@ -10,13 +10,14 @@ from hbutils.random import global_seed from sklearn import svm from sklearn.metrics import accuracy_score from torch.optim import lr_scheduler from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from torchmetrics import AUROC, AveragePrecision from torchvision.transforms import Compose from torch.utils.data import DataLoader from tqdm.auto import tqdm from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset, FastCharacterDataset, TEST_TRANSFORM, char_collect_fn from .loss import NTXentLoss, MLCELoss from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, FastCharacterDataset, TEST_TRANSFORM, char_collect_fn from .loss import MLCELoss from .model import CCIP from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR Loading @@ -26,7 +27,6 @@ _CKPT_DIR = os.path.join(_TRAIN_DIR, 'ckpts') _CKPT_PATTERN = re.compile(r'^ccip-(?P<name>[a-zA-Z\d_\-]+)-(?P<epoch>\d+)\.ckpt$') def _find_latest_ckpt(name: str) -> Optional[str]: if os.path.exists(_CKPT_DIR): ckpts = [] Loading @@ -43,7 +43,6 @@ def _find_latest_ckpt(name: str) -> Optional[str]: else: return None def _ckpt_epoch(filename: Optional[str]) -> Optional[int]: if filename is not None: matching = _CKPT_PATTERN.fullmatch(os.path.basename(filename)) Loading @@ -54,7 +53,6 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]: else: return None def _sample_analysis(poss, negs, svm_samples: int = 10000): poss_cnt, negs_cnt = poss.shape[0], negs.shape[0] total = poss_cnt+negs_cnt Loading @@ -72,18 +70,12 @@ def _sample_analysis(poss, negs, svm_samples: int = 10000): model.fit(features.reshape(-1, 1), labels) predictions = model.predict(features.reshape(-1, 1)) coef = model.coef_.reshape(-1)[0].tolist() inter = model.intercept_.reshape(-1)[0].tolist() threshold = -inter / coef return poss.mean().item(), poss.std().item(), negs.mean().item(), negs.std().item(), \ threshold, accuracy_score(labels, predictions) return poss.mean().item(), poss.std().item(), negs.mean().item(), negs.std().item(), accuracy_score(labels, predictions) def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None, train_ratio: float = 0.8, max_epochs: int = 500, group_size: int = 30, learning_rate: float = 0.001, weight_decay: float = 1e-2, tau: float = 0.15, save_per_epoch: int = 10, eval_epoch: int = 5, num_workers=8, save_per_epoch: int = 10, eval_epoch: int = 5, log_iter: int = 500, num_workers=8, model_name: str = 'clip/ViT-B/32', seed: Optional[int] = 0): if seed is not None: # native random, numpy, torch and faker's seeds are includes Loading Loading @@ -148,10 +140,13 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio model, optimizer, train_dataloader, test_dataloader, scheduler = \ accelerator.prepare(model, optimizer, train_dataloader, test_dataloader, scheduler) metric_auroc = AUROC(task="binary") metric_ap = AveragePrecision(task="binary") for epoch in range(previous_epoch+1, max_epochs+1): running_loss = 0.0 train_pos_total = 0 positive_sims, negative_sims = [], [] pred_list, gt_list = [], [] model.train() for i, (inputs, char_ids) in enumerate(tqdm(train_dataloader)): train_dataloader.dataset.reset() Loading @@ -174,72 +169,62 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio scheduler.step() running_loss += loss.item()*len(char_ids) train_pos_total += len(char_ids) gt = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu() mask = torch.ones_like(outputs) mask -= torch.diag_embed(torch.diag(mask)) outputs = outputs.detach().cpu() gt_diag0 = gt.clone() gt_diag0.diagonal().copy_(torch.zeros(len(char_ids))) # outputs.diagonal().copy_(torch.ones(len(char_ids))*-10000) # max_idxs = outputs.argsort(dim=-1) # for max_idx, n_pos in zip(max_idxs, gt.sum(dim=1)): # train_pos_total += n_pos # positive_sims.append(outputs[labels]) # negative_sims.append(outputs[~labels]) train_pos_total += gt_diag0.sum() positive_sims.append(outputs[gt_diag0]) negative_sims.append(outputs[~gt]) epoch_loss = running_loss #/ train_pos_total train_psims = torch.cat(positive_sims) train_nsims = torch.cat(negative_sims) train_pos_mean, train_pos_std, train_neg_mean, train_neg_std, train_threshold, train_acc_svm = \ _sample_analysis(train_psims, train_nsims) gt_same = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu() pred_list.append(outputs[mask]) gt_list.append(gt_same.long()[mask]) if (i+1)%log_iter == 0: pred_t, gt_t = accelerator.gather_for_metrics((torch.cat(pred_list), torch.cat(gt_list))) if accelerator.is_local_main_process: logging.info(f'Epoch [{epoch}/{max_epochs}], loss: {epoch_loss:.6f}, ' f'acc_svm: {train_acc_svm:.6f}, threshold: {train_threshold:.6f}.') auc = metric_auroc(pred_t, gt_t).item() ap = metric_ap(pred_t, gt_t).item() mean_loss = running_loss/train_pos_total logging.info(f'Epoch [{epoch}/{max_epochs}], loss: {mean_loss:.6f}, AUC: {auc:.3e}, AP: {ap:.3e}.') if writer: writer.add_scalar('train/loss', epoch_loss, epoch) writer.add_scalar('train/pos/mean', train_pos_mean, epoch) writer.add_scalar('train/pos/std', train_pos_std, epoch) writer.add_scalar('train/neg/mean', train_neg_mean, epoch) writer.add_scalar('train/neg/std', train_neg_std, epoch) writer.add_scalar('train/threshold', train_threshold, epoch) writer.add_scalar('train/acc_svm', train_acc_svm, epoch) writer.add_scalar('train/loss', mean_loss, epoch) writer.add_scalar('train/auc', auc, epoch) writer.add_scalar('train/ap', auc, epoch) pred_list.clear() gt_list.clear() running_loss = 0. train_pos_total = 0 model.eval() if epoch%eval_epoch == 0: with torch.no_grad(): positive_sims, negative_sims = [], [] pred_list, gt_list = [], [] for i, (inputs, char_ids) in enumerate(tqdm(test_dataloader)): inputs = inputs.to(accelerator.device) # BxCxHxW char_ids = char_ids.to(accelerator.device) # B outputs = model(inputs) # BxB gt = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu() mask = torch.ones_like(outputs) mask -= torch.diag_embed(torch.diag(mask)) outputs = outputs.detach().cpu() gt_diag0 = gt.clone() gt_diag0.diagonal().copy_(torch.zeros(len(char_ids))) train_pos_total += gt_diag0.sum() positive_sims.append(outputs[gt_diag0]) negative_sims.append(outputs[~gt]) test_psims = torch.cat(positive_sims) test_nsims = torch.cat(negative_sims) test_pos_mean, test_pos_std, test_neg_mean, test_neg_std, test_threshold, test_acc_svm = \ _sample_analysis(test_psims, test_nsims) gt_same = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu() pred_list.append(outputs[mask]) gt_list.append(gt_same.long()[mask]) pred_t, gt_t = accelerator.gather_for_metrics((torch.cat(pred_list), torch.cat(gt_list))) if accelerator.is_local_main_process: logging.info(f'Epoch {epoch}, ' f'acc_svm: {test_acc_svm:.6f}, threshold: {test_threshold:.6f}') auc = metric_auroc(pred_t, gt_t).item() ap = metric_ap(pred_t, gt_t).item() mean_loss = running_loss/train_pos_total logging.info(f'Epoch [{epoch}/{max_epochs}], loss: {mean_loss:.6f}, AUC: {auc:.3e}, AP: {ap:.3e}.') if writer: writer.add_scalar('test/pos/mean', test_pos_mean, epoch) writer.add_scalar('test/pos/std', test_pos_std, epoch) writer.add_scalar('test/neg/mean', test_neg_mean, epoch) writer.add_scalar('test/neg/std', test_neg_std, epoch) writer.add_scalar('test/threshold', test_threshold, epoch) writer.add_scalar('test/acc_svm', test_acc_svm, epoch) writer.add_scalar('test/loss', mean_loss, epoch) writer.add_scalar('test/auc', auc, epoch) writer.add_scalar('test/ap', auc, epoch) pred_list.clear() gt_list.clear() if accelerator.is_local_main_process and epoch%save_per_epoch == 0: current_ckpt_file = os.path.join(_CKPT_DIR, f'ccip-{session_name}-{epoch}.ckpt') Loading Loading
zoo/ccip/train_.py +64 −79 Original line number Diff line number Diff line Loading @@ -10,13 +10,14 @@ from hbutils.random import global_seed from sklearn import svm from sklearn.metrics import accuracy_score from torch.optim import lr_scheduler from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from torchmetrics import AUROC, AveragePrecision from torchvision.transforms import Compose from torch.utils.data import DataLoader from tqdm.auto import tqdm from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset, FastCharacterDataset, TEST_TRANSFORM, char_collect_fn from .loss import NTXentLoss, MLCELoss from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, FastCharacterDataset, TEST_TRANSFORM, char_collect_fn from .loss import MLCELoss from .model import CCIP from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR Loading @@ -26,7 +27,6 @@ _CKPT_DIR = os.path.join(_TRAIN_DIR, 'ckpts') _CKPT_PATTERN = re.compile(r'^ccip-(?P<name>[a-zA-Z\d_\-]+)-(?P<epoch>\d+)\.ckpt$') def _find_latest_ckpt(name: str) -> Optional[str]: if os.path.exists(_CKPT_DIR): ckpts = [] Loading @@ -43,7 +43,6 @@ def _find_latest_ckpt(name: str) -> Optional[str]: else: return None def _ckpt_epoch(filename: Optional[str]) -> Optional[int]: if filename is not None: matching = _CKPT_PATTERN.fullmatch(os.path.basename(filename)) Loading @@ -54,7 +53,6 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]: else: return None def _sample_analysis(poss, negs, svm_samples: int = 10000): poss_cnt, negs_cnt = poss.shape[0], negs.shape[0] total = poss_cnt+negs_cnt Loading @@ -72,18 +70,12 @@ def _sample_analysis(poss, negs, svm_samples: int = 10000): model.fit(features.reshape(-1, 1), labels) predictions = model.predict(features.reshape(-1, 1)) coef = model.coef_.reshape(-1)[0].tolist() inter = model.intercept_.reshape(-1)[0].tolist() threshold = -inter / coef return poss.mean().item(), poss.std().item(), negs.mean().item(), negs.std().item(), \ threshold, accuracy_score(labels, predictions) return poss.mean().item(), poss.std().item(), negs.mean().item(), negs.std().item(), accuracy_score(labels, predictions) def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None, train_ratio: float = 0.8, max_epochs: int = 500, group_size: int = 30, learning_rate: float = 0.001, weight_decay: float = 1e-2, tau: float = 0.15, save_per_epoch: int = 10, eval_epoch: int = 5, num_workers=8, save_per_epoch: int = 10, eval_epoch: int = 5, log_iter: int = 500, num_workers=8, model_name: str = 'clip/ViT-B/32', seed: Optional[int] = 0): if seed is not None: # native random, numpy, torch and faker's seeds are includes Loading Loading @@ -148,10 +140,13 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio model, optimizer, train_dataloader, test_dataloader, scheduler = \ accelerator.prepare(model, optimizer, train_dataloader, test_dataloader, scheduler) metric_auroc = AUROC(task="binary") metric_ap = AveragePrecision(task="binary") for epoch in range(previous_epoch+1, max_epochs+1): running_loss = 0.0 train_pos_total = 0 positive_sims, negative_sims = [], [] pred_list, gt_list = [], [] model.train() for i, (inputs, char_ids) in enumerate(tqdm(train_dataloader)): train_dataloader.dataset.reset() Loading @@ -174,72 +169,62 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio scheduler.step() running_loss += loss.item()*len(char_ids) train_pos_total += len(char_ids) gt = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu() mask = torch.ones_like(outputs) mask -= torch.diag_embed(torch.diag(mask)) outputs = outputs.detach().cpu() gt_diag0 = gt.clone() gt_diag0.diagonal().copy_(torch.zeros(len(char_ids))) # outputs.diagonal().copy_(torch.ones(len(char_ids))*-10000) # max_idxs = outputs.argsort(dim=-1) # for max_idx, n_pos in zip(max_idxs, gt.sum(dim=1)): # train_pos_total += n_pos # positive_sims.append(outputs[labels]) # negative_sims.append(outputs[~labels]) train_pos_total += gt_diag0.sum() positive_sims.append(outputs[gt_diag0]) negative_sims.append(outputs[~gt]) epoch_loss = running_loss #/ train_pos_total train_psims = torch.cat(positive_sims) train_nsims = torch.cat(negative_sims) train_pos_mean, train_pos_std, train_neg_mean, train_neg_std, train_threshold, train_acc_svm = \ _sample_analysis(train_psims, train_nsims) gt_same = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu() pred_list.append(outputs[mask]) gt_list.append(gt_same.long()[mask]) if (i+1)%log_iter == 0: pred_t, gt_t = accelerator.gather_for_metrics((torch.cat(pred_list), torch.cat(gt_list))) if accelerator.is_local_main_process: logging.info(f'Epoch [{epoch}/{max_epochs}], loss: {epoch_loss:.6f}, ' f'acc_svm: {train_acc_svm:.6f}, threshold: {train_threshold:.6f}.') auc = metric_auroc(pred_t, gt_t).item() ap = metric_ap(pred_t, gt_t).item() mean_loss = running_loss/train_pos_total logging.info(f'Epoch [{epoch}/{max_epochs}], loss: {mean_loss:.6f}, AUC: {auc:.3e}, AP: {ap:.3e}.') if writer: writer.add_scalar('train/loss', epoch_loss, epoch) writer.add_scalar('train/pos/mean', train_pos_mean, epoch) writer.add_scalar('train/pos/std', train_pos_std, epoch) writer.add_scalar('train/neg/mean', train_neg_mean, epoch) writer.add_scalar('train/neg/std', train_neg_std, epoch) writer.add_scalar('train/threshold', train_threshold, epoch) writer.add_scalar('train/acc_svm', train_acc_svm, epoch) writer.add_scalar('train/loss', mean_loss, epoch) writer.add_scalar('train/auc', auc, epoch) writer.add_scalar('train/ap', auc, epoch) pred_list.clear() gt_list.clear() running_loss = 0. train_pos_total = 0 model.eval() if epoch%eval_epoch == 0: with torch.no_grad(): positive_sims, negative_sims = [], [] pred_list, gt_list = [], [] for i, (inputs, char_ids) in enumerate(tqdm(test_dataloader)): inputs = inputs.to(accelerator.device) # BxCxHxW char_ids = char_ids.to(accelerator.device) # B outputs = model(inputs) # BxB gt = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu() mask = torch.ones_like(outputs) mask -= torch.diag_embed(torch.diag(mask)) outputs = outputs.detach().cpu() gt_diag0 = gt.clone() gt_diag0.diagonal().copy_(torch.zeros(len(char_ids))) train_pos_total += gt_diag0.sum() positive_sims.append(outputs[gt_diag0]) negative_sims.append(outputs[~gt]) test_psims = torch.cat(positive_sims) test_nsims = torch.cat(negative_sims) test_pos_mean, test_pos_std, test_neg_mean, test_neg_std, test_threshold, test_acc_svm = \ _sample_analysis(test_psims, test_nsims) gt_same = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu() pred_list.append(outputs[mask]) gt_list.append(gt_same.long()[mask]) pred_t, gt_t = accelerator.gather_for_metrics((torch.cat(pred_list), torch.cat(gt_list))) if accelerator.is_local_main_process: logging.info(f'Epoch {epoch}, ' f'acc_svm: {test_acc_svm:.6f}, threshold: {test_threshold:.6f}') auc = metric_auroc(pred_t, gt_t).item() ap = metric_ap(pred_t, gt_t).item() mean_loss = running_loss/train_pos_total logging.info(f'Epoch [{epoch}/{max_epochs}], loss: {mean_loss:.6f}, AUC: {auc:.3e}, AP: {ap:.3e}.') if writer: writer.add_scalar('test/pos/mean', test_pos_mean, epoch) writer.add_scalar('test/pos/std', test_pos_std, epoch) writer.add_scalar('test/neg/mean', test_neg_mean, epoch) writer.add_scalar('test/neg/std', test_neg_std, epoch) writer.add_scalar('test/threshold', test_threshold, epoch) writer.add_scalar('test/acc_svm', test_acc_svm, epoch) writer.add_scalar('test/loss', mean_loss, epoch) writer.add_scalar('test/auc', auc, epoch) writer.add_scalar('test/ap', auc, epoch) pred_list.clear() gt_list.clear() if accelerator.is_local_main_process and epoch%save_per_epoch == 0: current_ckpt_file = os.path.join(_CKPT_DIR, f'ccip-{session_name}-{epoch}.ckpt') Loading