Commit 4e0a4931 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add new sampling method

parent 3b099aa7
Loading
Loading
Loading
Loading
+19 −0
Original line number Diff line number Diff line
@@ -26,3 +26,22 @@ class FocalLoss(nn.Module):
            weight=self.weight,
            reduction=self.reduction
        )


class NTXentLoss(nn.Module):
    """
    Temperature-scaled cross-entropy loss over a flat list of pair similarities.

    Inspired from https://blog.csdn.net/cziun/article/details/119118768 .

    All ``N`` similarities are divided by the temperature ``tau`` and normalised
    together with a softmax; the loss is the mean negative log-probability of the
    entries that are flagged as positive pairs.
    """

    def __init__(self, tau: float = 1.0):
        """
        :param tau: Temperature the similarities are divided by before the softmax.
        """
        nn.Module.__init__(self)
        # Kept as a buffer so the temperature moves together with the module
        # (e.g. on ``.to(device)``) without being a trainable parameter.
        self.register_buffer('tau', torch.as_tensor(tau, dtype=torch.float))

    def forward(self, sim_tensors, state_tensors):
        """
        :param sim_tensors: Similarities, float32[N]
        :param state_tensors: Positive sample or not, bool[N]
        :return: Scalar tensor, mean of ``-log_softmax(sim / tau)`` over the positive
            entries. When there is no positive entry a zero that is still attached to
            the autograd graph is returned, so calling ``backward`` stays valid.
        """
        # log_softmax is computed in log-space; the former log(softmax(x)) underflowed
        # to log(0) = -inf as soon as one similarity dominated the others.
        log_items = -torch.log_softmax(sim_tensors / self.tau, dim=-1)
        positive_items = log_items[state_tensors]
        if positive_items.numel() == 0:
            # mean() of an empty tensor is NaN, which would poison the gradients.
            return log_items.sum() * 0.0
        return positive_items.mean()
+2 −4
Original line number Diff line number Diff line
@@ -11,11 +11,9 @@ class CCIPBatchMetrics(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.sim = nn.CosineSimilarity(dim=-1)
        self.fc = nn.Linear(1, 2)

    def forward(self, x):  # x: BxN
        x = self.sim(x, x.unsqueeze(1))
        x = self.fc(x.unsqueeze(-1))
        x = self.sim(x, x.unsqueeze(1))  # BxB
        return x


@@ -43,7 +41,7 @@ class CCIP(torch.nn.Module):
    def forward(self, x):
        # x: BxCxHxW
        x = self.feature(x)  # BxF
        x = self.metrics(x)  # BxBx2
        x = self.metrics(x)  # BxB
        return x


+1 −1
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ def _possibility(n, m):
    v = _array_create(n, m)
    same, not_same = 0, 0
    for i in range(n):
        for j in range(i, n):
        for j in range(i + 1, n):
            if v[i] == v[j]:
                same += 1
            else:
+73 −53
Original line number Diff line number Diff line
import os
import random
import re
from typing import Optional

@@ -6,13 +7,15 @@ import torch
from accelerate import Accelerator
from ditk import logging
from hbutils.random import global_seed
from sklearn import svm
from sklearn.metrics import accuracy_score
from torch.optim import lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import Compose
from tqdm.auto import tqdm

from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset, TEST_TRANSFORM
from .loss import FocalLoss
from .loss import NTXentLoss
from .model import CCIP
from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR

@@ -51,9 +54,32 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]:
        return None


def _sample_analysis(poss, negs, svm_samples: int = 10000):
    """
    Summarise how well positive and negative pair similarities are separated.

    A linear SVM is fitted on the one-dimensional similarities (subsampled
    proportionally when there are more than ``svm_samples`` of them) to obtain a
    decision threshold and the accuracy reached with it on the fitted samples.

    :param poss: Similarities of the positive pairs, 1-D tensor.
    :param negs: Similarities of the negative pairs, 1-D tensor.
    :param svm_samples: Approximate maximum number of samples the SVM is fitted on.
    :return: Tuple ``(poss_mean, poss_std, negs_mean, negs_std, threshold, accuracy)``.
        Means and stds are computed on the full inputs; ``threshold`` is NaN when the
        fitted coefficient is zero (no finite decision boundary).
    """
    # Only statistics are computed here, so drop any autograd history up front.
    poss, negs = poss.detach(), negs.detach()
    poss_cnt, negs_cnt = poss.shape[0], negs.shape[0]
    total = poss_cnt + negs_cnt

    def _quota(cnt):
        # Proportional share of the sample budget; a non-empty class always keeps
        # at least one sample so rounding cannot remove it from the SVM input.
        if cnt == 0:
            return 0
        return min(cnt, max(1, int(round(cnt * svm_samples / total))))

    if total > svm_samples:
        s_poss = poss[random.sample(range(poss_cnt), k=_quota(poss_cnt))]
        s_negs = negs[random.sample(range(negs_cnt), k=_quota(negs_cnt))]
    else:
        s_poss, s_negs = poss, negs

    # scikit-learn works on CPU NumPy arrays, not on (possibly CUDA) torch tensors.
    features = torch.cat([s_poss, s_negs]).reshape(-1, 1).float().cpu().numpy()
    labels = torch.cat([torch.ones_like(s_poss), torch.zeros_like(s_negs)]).float().cpu().numpy()

    model = svm.SVC(kernel='linear')  # linear kernel
    model.fit(features, labels)
    predictions = model.predict(features)

    coef = model.coef_.reshape(-1)[0].item()
    inter = model.intercept_.reshape(-1)[0].item()
    # Decision boundary of coef * x + inter = 0 is x = -inter / coef.
    threshold = -inter / coef if coef != 0 else float('nan')

    return poss.mean(), poss.std(), negs.mean(), negs.std(), threshold, accuracy_score(labels, predictions)


def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None,
          train_ratio: float = 0.8, max_epochs: int = 500, group_size: int = 100,
          learning_rate: float = 0.001, weight_decay: float = 1e-3, preference: float = 0.0,
          learning_rate: float = 0.001, weight_decay: float = 1e-3,
          save_per_epoch: int = 10, eval_epoch: int = 5,
          model_name: str = 'clip/ViT-B/32', seed: Optional[int] = 0):
    if seed is not None:
@@ -107,11 +133,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    else:
        logging.info(f'No checkpoint found, new model will be used.')

    if preference < 0:
        loss_weight = torch.as_tensor([torch.e, 1.0]) ** -preference
    else:
        loss_weight = torch.as_tensor([1.0, torch.e]) ** preference
    loss_fn = FocalLoss(weight=loss_weight).to(accelerator.device)
    loss_fn = NTXentLoss().to(accelerator.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = lr_scheduler.OneCycleLR(
        optimizer, max_lr=learning_rate,
@@ -124,8 +146,8 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio

    for epoch in range(previous_epoch + 1, max_epochs + 1):
        running_loss = 0.0
        train_correct, train_total = 0, 0
        train_fp, train_fn = 0, 0
        train_pos_total = 0
        positive_sims, negative_sims = [], []
        model.train()
        for i, (inputs, char_ids) in enumerate(tqdm(train_dataset)):
            inputs = inputs.float()
@@ -133,74 +155,72 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
            char_ids = char_ids.to(accelerator.device)  # B

            ix = torch.arange(0, char_ids.shape[0])
            mask = ix >= ix.reshape(-1, 1)  # BxB, remove duplicated
            logits = model(inputs)  # BxBx2
            outputs = logits[mask]  # Nx2
            labels = (char_ids == char_ids.reshape(-1, 1))[mask]  # N
            labels = labels.type(torch.long).to(accelerator.device)  # N

            preds = torch.argmax(outputs, dim=1)
            train_correct += (preds == labels).sum().item()
            train_fp += (preds[labels == 0] == 1).sum().item()
            train_fn += (preds[labels == 1] == 0).sum().item()
            train_total += labels.shape[0]
            mask = ix > ix.reshape(-1, 1)  # BxB, remove duplicated
            similarities = model(inputs)  # BxB
            outputs = similarities[mask]  # N
            labels = (char_ids == char_ids.reshape(-1, 1))[mask].to(accelerator.device)  # N

            loss = loss_fn(outputs, labels)
            accelerator.backward(loss)
            optimizer.step()
            running_loss += loss.item() * labels.shape[0]
            train_pos_total += labels.sum()
            running_loss += loss.item() * labels.sum()
            scheduler.step()

        epoch_loss = running_loss / train_total
        train_accuracy = train_correct / train_total
        train_fp_p = train_fp / train_total
        train_fn_p = train_fn / train_total
            positive_sims.append(outputs[labels])
            negative_sims.append(outputs[~labels])

        epoch_loss = running_loss / train_pos_total
        train_psims = torch.cat(positive_sims)
        train_nsims = torch.cat(negative_sims)
        train_psim_mean, train_psim_std, train_msim_mean, train_msim_std, train_threshold, train_acc_svm = \
            _sample_analysis(train_psims, train_nsims)

        if accelerator.is_local_main_process:
            logging.info(f'Epoch [{epoch}/{max_epochs}], loss: {epoch_loss:.6f}, '
                         f'train accuracy: {train_accuracy:.4f}, '
                         f'false positive: {train_fp_p:.4f}, false negative: {train_fn_p:.4f}')
                         f'acc_scm: {train_acc_svm:.6f}, threshold: {train_threshold:.6f}.')
            if writer:
                writer.add_scalar('train/loss', epoch_loss, epoch)
                writer.add_scalar('train/accuracy', train_accuracy, epoch)
                writer.add_scalar('train/fp', train_fp_p, epoch)
                writer.add_scalar('train/fn', train_fn_p, epoch)
                writer.add_scalar('train/psim/mean', train_psim_mean, epoch)
                writer.add_scalar('train/psim/std', train_psim_std, epoch)
                writer.add_scalar('train/msim/mean', train_msim_mean, epoch)
                writer.add_scalar('train/msim/std', train_msim_std, epoch)
                writer.add_scalar('train/threshold', train_threshold)
                writer.add_scalar('train/acc_svm', train_acc_svm)

        model.eval()
        if epoch % eval_epoch == 0:
            with torch.no_grad():
                test_correct, test_total = 0, 0
                test_fp, test_fn = 0, 0

                positive_sims, negative_sims = [], []
                for i, (inputs, char_ids) in enumerate(tqdm(test_dataset)):
                    inputs = inputs.float()
                    inputs = inputs.to(accelerator.device)  # BxCxHxW
                    char_ids = char_ids.to(accelerator.device)  # B

                    ix = torch.arange(0, char_ids.shape[0])
                    mask = ix >= ix.reshape(-1, 1)  # BxB, remove duplicated
                    logits = model(inputs)  # BxBx2
                    outputs = logits[mask]  # Nx2
                    labels = (char_ids == char_ids.reshape(-1, 1))[mask]  # N
                    labels = labels.type(torch.long).to(accelerator.device)  # N

                    preds = torch.argmax(outputs, dim=1)
                    test_correct += (preds == labels).sum().item()
                    test_fp += (preds[labels == 0] == 1).sum().item()
                    test_fn += (preds[labels == 1] == 0).sum().item()
                    test_total += labels.shape[0]

                test_accuracy = test_correct / test_total
                test_fp_p = test_fp / test_total
                test_fn_p = test_fn / test_total
                    mask = ix > ix.reshape(-1, 1)  # BxB, remove duplicated
                    similarities = model(inputs)  # BxB
                    outputs = similarities[mask]  # N
                    labels = (char_ids == char_ids.reshape(-1, 1))[mask].to(accelerator.device)  # N

                    positive_sims.append(outputs[labels])
                    negative_sims.append(outputs[~labels])

                test_psims = torch.cat(positive_sims)
                test_nsims = torch.cat(negative_sims)
                test_psim_mean, test_psim_std, test_msim_mean, test_msim_std, test_threshold, test_acc_svm = \
                    _sample_analysis(test_psims, test_nsims)

                if accelerator.is_local_main_process:
                    logging.info(f'Epoch {epoch}, test accuracy: {test_accuracy:.4f}, '
                                 f'false positive: {test_fp_p:.4f}, false negative: {test_fn_p:.4f}')
                    logging.info(f'Epoch {epoch}, '
                                 f'acc_scm: {test_acc_svm:.6f}, threshold: {test_threshold:.6f}')
                    if writer:
                        writer.add_scalar('test/accuracy', test_accuracy, epoch)
                        writer.add_scalar('test/fp', test_fp_p, epoch)
                        writer.add_scalar('test/fn', test_fn_p, epoch)
                        writer.add_scalar('test/psim/mean', test_psim_mean, epoch)
                        writer.add_scalar('test/psim/std', test_psim_std, epoch)
                        writer.add_scalar('test/msim/mean', test_msim_mean, epoch)
                        writer.add_scalar('test/msim/std', test_msim_std, epoch)
                        writer.add_scalar('test/threshold', test_threshold)
                        writer.add_scalar('test/acc_svm', test_acc_svm)

        if accelerator.is_local_main_process and epoch % save_per_epoch == 0:
            current_ckpt_file = os.path.join(_CKPT_DIR, f'ccip-{session_name}-{epoch}.ckpt')