Commit 736333a1 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): use new metrics and loss

parent 4e0a4931
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -33,9 +33,10 @@ class NTXentLoss(nn.Module):
    Inspired from https://blog.csdn.net/cziun/article/details/119118768 .
    """

    def __init__(self, tau: float = 1.0):
    def __init__(self, tau: float = 1.0, eps: float = 1e-8):
        nn.Module.__init__(self)
        self.register_buffer('tau', torch.as_tensor(tau, dtype=torch.float))
        self.register_buffer('eps', torch.as_tensor(eps, dtype=torch.float))

    def forward(self, sim_tensors, state_tensors):
        """
@@ -44,4 +45,4 @@ class NTXentLoss(nn.Module):
        """
        log_items = -torch.log(torch.softmax(sim_tensors / self.tau, dim=-1))
        positive_items = log_items[state_tensors]
        return positive_items.mean()
        return (positive_items.sum() + self.eps) / (positive_items.shape[0] + self.eps)
+26 −30
Original line number Diff line number Diff line
@@ -63,18 +63,20 @@ def _sample_analysis(poss, negs, svm_samples: int = 10000):
    else:
        s_poss, s_negs = poss, negs

    features = torch.cat([s_poss, s_negs])
    labels = torch.cat([torch.ones_like(s_poss), torch.zeros_like(s_negs)])
    s_poss, s_negs = s_poss.cpu(), s_negs.cpu()
    features = torch.cat([s_poss, s_negs]).detach().numpy()
    labels = torch.cat([torch.ones_like(s_poss), torch.zeros_like(s_negs)]).detach().numpy()

    model = svm.SVC(kernel='linear')  # linear kernel
    model.fit(features.reshape(-1, 1), labels)
    predictions = model.predict(features.reshape(-1, 1))

    coef = model.coef_.reshape(-1)[0].item()
    inter = model.intercept_.reshape(-1)[0].item()
    coef = model.coef_.reshape(-1)[0].tolist()
    inter = model.intercept_.reshape(-1)[0].tolist()
    threshold = -coef / inter

    return poss.mean(), poss.std(), negs.mean(), negs.std(), threshold, accuracy_score(labels, predictions)
    return poss.mean().item(), poss.std().item(), negs.mean().item(), negs.std().item(), \
           threshold, accuracy_score(labels, predictions)


def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None,
@@ -101,15 +103,9 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        os.makedirs(_CKPT_DIR, exist_ok=True)
        writer = SummaryWriter(_log_dir)
        writer.add_custom_scalars({
            "general": {
                "accuracy": ["Multiline", ["train/accuracy", "test/accuracy"]],
                "false": ["Multiline", ["test/fn", "test/fp", "train/fn", "train/fp"]],
            },
            "test": {
                "false": ["Multiline", ["test/fn", "test/fp"]],
            },
            "train": {
                "false": ["Multiline", ["train/fn", "train/fp"]],
            "contrastive": {
                "train": ["Multiline", ["train/pos/mean", "train/neg/mean"]],
                "test": ["Multiline", ["test/pos/mean", "test/neg/mean"]],
            },
        })
    else:
@@ -173,20 +169,20 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        epoch_loss = running_loss / train_pos_total
        train_psims = torch.cat(positive_sims)
        train_nsims = torch.cat(negative_sims)
        train_psim_mean, train_psim_std, train_msim_mean, train_msim_std, train_threshold, train_acc_svm = \
        train_pos_mean, train_pos_std, train_neg_mean, train_neg_std, train_threshold, train_acc_svm = \
            _sample_analysis(train_psims, train_nsims)

        if accelerator.is_local_main_process:
            logging.info(f'Epoch [{epoch}/{max_epochs}], loss: {epoch_loss:.6f}, '
                         f'acc_scm: {train_acc_svm:.6f}, threshold: {train_threshold:.6f}.')
                         f'acc_svm: {train_acc_svm:.6f}, threshold: {train_threshold:.6f}.')
            if writer:
                writer.add_scalar('train/loss', epoch_loss, epoch)
                writer.add_scalar('train/psim/mean', train_psim_mean, epoch)
                writer.add_scalar('train/psim/std', train_psim_std, epoch)
                writer.add_scalar('train/msim/mean', train_msim_mean, epoch)
                writer.add_scalar('train/msim/std', train_msim_std, epoch)
                writer.add_scalar('train/threshold', train_threshold)
                writer.add_scalar('train/acc_svm', train_acc_svm)
                writer.add_scalar('train/pos/mean', train_pos_mean, epoch)
                writer.add_scalar('train/pos/std', train_pos_std, epoch)
                writer.add_scalar('train/neg/mean', train_neg_mean, epoch)
                writer.add_scalar('train/neg/std', train_neg_std, epoch)
                writer.add_scalar('train/threshold', train_threshold, epoch)
                writer.add_scalar('train/acc_svm', train_acc_svm, epoch)

        model.eval()
        if epoch % eval_epoch == 0:
@@ -208,19 +204,19 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio

                test_psims = torch.cat(positive_sims)
                test_nsims = torch.cat(negative_sims)
                test_psim_mean, test_psim_std, test_msim_mean, test_msim_std, test_threshold, test_acc_svm = \
                test_pos_mean, test_pos_std, test_neg_mean, test_neg_std, test_threshold, test_acc_svm = \
                    _sample_analysis(test_psims, test_nsims)

                if accelerator.is_local_main_process:
                    logging.info(f'Epoch {epoch}, '
                                 f'acc_scm: {test_acc_svm:.6f}, threshold: {test_threshold:.6f}')
                                 f'acc_svm: {test_acc_svm:.6f}, threshold: {test_threshold:.6f}')
                    if writer:
                        writer.add_scalar('test/psim/mean', test_psim_mean, epoch)
                        writer.add_scalar('test/psim/std', test_psim_std, epoch)
                        writer.add_scalar('test/msim/mean', test_msim_mean, epoch)
                        writer.add_scalar('test/msim/std', test_msim_std, epoch)
                        writer.add_scalar('test/threshold', test_threshold)
                        writer.add_scalar('test/acc_svm', test_acc_svm)
                        writer.add_scalar('test/pos/mean', test_pos_mean, epoch)
                        writer.add_scalar('test/pos/std', test_pos_std, epoch)
                        writer.add_scalar('test/neg/mean', test_neg_mean, epoch)
                        writer.add_scalar('test/neg/std', test_neg_std, epoch)
                        writer.add_scalar('test/threshold', test_threshold, epoch)
                        writer.add_scalar('test/acc_svm', test_acc_svm, epoch)

        if accelerator.is_local_main_process and epoch % save_per_epoch == 0:
            current_ckpt_file = os.path.join(_CKPT_DIR, f'ccip-{session_name}-{epoch}.ckpt')