Commit 5bc99c52 authored by dzy7e's avatar dzy7e
Browse files

auc, map

parent 1fb954ee
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@ torch<2
lpips
matplotlib
torchvision
torchmetrics>=0.11.4
tqdm
onnx
onnxoptimizer
+28 −21
Original line number Diff line number Diff line
@@ -75,7 +75,7 @@ def _sample_analysis(poss, negs, svm_samples: int = 10000):
def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None,
          train_ratio: float = 0.8, max_epochs: int = 500, group_size: int = 30,
          learning_rate: float = 0.001, weight_decay: float = 1e-2, tau: float = 0.15,
          save_per_epoch: int = 10, eval_epoch: int = 5, log_iter: int = 500, num_workers=8,
          save_per_epoch: int = 10, eval_epoch: int = 5, loss_log_iter:int = 20, log_iter: int = 500, num_workers=8,
          model_name: str = 'clip/ViT-B/32', seed: Optional[int] = 0):
    if seed is not None:
        # native random, numpy, torch and faker's seeds are includes
@@ -171,22 +171,29 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
            running_loss += loss.item()*len(char_ids)
            train_pos_total += len(char_ids)

            mask = torch.ones_like(outputs)
            mask -= torch.diag_embed(torch.diag(mask))
            mask = torch.ones_like(outputs).bool()
            mask ^= torch.diag_embed(torch.diag(mask))
            outputs = outputs.detach().cpu()
            gt_same = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu()
            pred_list.append(outputs[mask])
            gt_list.append(gt_same.long()[mask])

            with torch.no_grad():
                if (i+1)%loss_log_iter == 0:
                    mean_loss = running_loss/train_pos_total
                    if writer:
                        writer.add_scalar('train/loss', mean_loss, epoch)

                if (i+1)%log_iter == 0:
                pred_t, gt_t = accelerator.gather_for_metrics((torch.cat(pred_list), torch.cat(gt_list)))
                    pred_t = torch.cat(pred_list).to(accelerator.device)
                    gt_t = torch.cat(gt_list).to(accelerator.device)
                    pred_t, gt_t = accelerator.gather_for_metrics((pred_t, gt_t))
                    if accelerator.is_local_main_process:
                        auc = metric_auroc(pred_t, gt_t).item()
                        ap = metric_ap(pred_t, gt_t).item()
                    mean_loss = running_loss/train_pos_total
                        logging.info(f'Epoch [{epoch}/{max_epochs}], loss: {mean_loss:.6f}, AUC: {auc:.3e}, AP: {ap:.3e}.')
                        if writer:
                        writer.add_scalar('train/loss', mean_loss, epoch)
                            #writer.add_scalar('train/loss', mean_loss, epoch)
                            writer.add_scalar('train/auc', auc, epoch)
                            writer.add_scalar('train/ap', auc, epoch)

@@ -205,21 +212,21 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio

                    outputs = model(inputs)  # BxB

                    mask = torch.ones_like(outputs)
                    mask -= torch.diag_embed(torch.diag(mask))
                    mask = torch.ones_like(outputs).bool()
                    mask ^= torch.diag_embed(torch.diag(mask))
                    outputs = outputs.detach().cpu()
                    gt_same = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu()
                    pred_list.append(outputs[mask])
                    gt_list.append(gt_same.long()[mask])

                pred_t, gt_t = accelerator.gather_for_metrics((torch.cat(pred_list), torch.cat(gt_list)))
                pred_t = torch.cat(pred_list).to(accelerator.device)
                gt_t = torch.cat(gt_list).to(accelerator.device)
                pred_t, gt_t = accelerator.gather_for_metrics((pred_t, gt_t))
                if accelerator.is_local_main_process:
                    auc = metric_auroc(pred_t, gt_t).item()
                    ap = metric_ap(pred_t, gt_t).item()
                    mean_loss = running_loss/train_pos_total
                    logging.info(f'Epoch [{epoch}/{max_epochs}], loss: {mean_loss:.6f}, AUC: {auc:.3e}, AP: {ap:.3e}.')
                    logging.info(f'Epoch [{epoch}/{max_epochs}], AUC: {auc:.3e}, AP: {ap:.3e}.')
                    if writer:
                        writer.add_scalar('test/loss', mean_loss, epoch)
                        writer.add_scalar('test/auc', auc, epoch)
                        writer.add_scalar('test/ap', auc, epoch)