Commit a6950e04 authored by dzy7e's avatar dzy7e
Browse files

auc, map

parent 5bc99c52
Loading
Loading
Loading
Loading
+5 −11
Original line number Diff line number Diff line
@@ -148,18 +148,12 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        train_pos_total = 0
        pred_list, gt_list = [], []
        model.train()
        num_iter = len(train_dataloader)
        for i, (inputs, char_ids) in enumerate(tqdm(train_dataloader)):
            train_dataloader.dataset.reset()
            inputs = inputs.to(accelerator.device)  # BxCxHxW
            char_ids = char_ids.to(accelerator.device)  # B

            # B = len(char_ids)
            # mask = torch.triu(torch.ones(B,B),diagonal=1).to(accelerator.device)  # BxB, remove duplicated
            # similarities = model(inputs)  # BxB
            # outputs = similarities[mask]  # N
            # labels = (char_ids.view(-1,1) == char_ids.view(1,-1))[mask]  # N
            # labels = char_ids

            outputs = model(inputs)  # BxB
            labels = char_ids

@@ -182,7 +176,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
                if (i+1)%loss_log_iter == 0:
                    mean_loss = running_loss/train_pos_total
                    if writer:
                        writer.add_scalar('train/loss', mean_loss, epoch)
                        writer.add_scalar('train/loss', mean_loss, epoch*num_iter + i)

                if (i+1)%log_iter == 0:
                    pred_t = torch.cat(pred_list).to(accelerator.device)
@@ -191,11 +185,11 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
                    if accelerator.is_local_main_process:
                        auc = metric_auroc(pred_t, gt_t).item()
                        ap = metric_ap(pred_t, gt_t).item()
                        logging.info(f'Epoch [{epoch}/{max_epochs}], loss: {mean_loss:.6f}, AUC: {auc:.3e}, AP: {ap:.3e}.')
                        logging.info(f'Epoch [{epoch}/{max_epochs}]<{i+1}/{num_iter}>, loss: {mean_loss:.6f}, AUC: {auc:.3e}, AP: {ap:.3e}.')
                        if writer:
                            #writer.add_scalar('train/loss', mean_loss, epoch)
                            writer.add_scalar('train/auc', auc, epoch)
                            writer.add_scalar('train/ap', auc, epoch)
                            writer.add_scalar('train/auc', auc, epoch*num_iter + i)
                            writer.add_scalar('train/ap', auc, epoch*num_iter + i)

                        pred_list.clear()
                        gt_list.clear()