Commit 5eb326d6 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add init ccip training code

parent 7f196444
Loading
Loading
Loading
Loading

zoo/ccip/__main__.py

0 → 100644
+16 −0
Original line number Diff line number Diff line
from functools import partial

import click

from ..utils import GLOBAL_CONTEXT_SETTINGS
from ..utils import print_version as _origin_print_version

print_version = partial(_origin_print_version, 'zoo.ccip')


@click.group(context_settings={**GLOBAL_CONTEXT_SETTINGS})
@click.option('-v', '--version', is_flag=True,
              callback=print_version, expose_value=False, is_eager=True,
              help="Utils with pixiv resources.")
def cli():
    pass  # pragma: no cover
+93 −13
Original line number Diff line number Diff line
import glob
import os.path
from typing import List, Tuple
import random
from typing import List, Tuple, Dict

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from imgutils.data import load_image
from .prob import get_reg_for_prob

TRAIN_TRANSFORM = transforms.Compose([
    transforms.Resize(300),
    transforms.RandomCrop(250, padding=25, pad_if_needed=True, padding_mode='reflect'),
    transforms.RandomRotation((-15, 15)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.10, 0.10, 0.10, 0.10),
])
TEST_TRANSFORM = transforms.Compose([])

class CCIPDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        _ids, _maxid = {}, 0
        self.items: List[Tuple[str, int]] = []
        self.transform = transform

        for file in glob.glob(os.path.join(root_dir, '*', '*', '*.jpg')):
            dirname = os.path.normcase(os.path.normpath(os.path.dirname(os.path.abspath(file))))
            if dirname not in _ids:
                _ids[dirname] = _maxid
                _maxid += 1

            self.items.append((file, _ids[dirname]))
class ImagesDataset(Dataset):
    def __init__(self, items: List[Tuple[str, int]], transform=None):
        self.items: List[Tuple[str, int]] = items
        self.transform = transform

    def __len__(self):
        return len(self.items)
@@ -32,3 +36,79 @@ class CCIPDataset(Dataset):
            image = self.transform(image)

        return image, idx

    def split_dataset(self, test_prob: float = 0.2):
        total = len(self.items)
        test_ids = set(random.sample(list(range(total)), k=int(total * test_prob)))

        train_items, test_items = [], []
        for i, item in enumerate(self.items):
            if i in test_ids:
                test_items.append(item)
            else:
                train_items.append(item)

        return ImagesDataset(train_items, self.transform), ImagesDataset(test_items, self.transform)


class CCIPImagesDataset(ImagesDataset):
    def __init__(self, root_dir, transform=None):
        _ids, _maxid = {}, 0
        _items: List[Tuple[str, int]] = []
        for file in glob.glob(os.path.join(root_dir, '*', '*', '*.jpg')):
            dirname = os.path.normcase(os.path.normpath(os.path.dirname(os.path.abspath(file))))
            if dirname not in _ids:
                _ids[dirname] = _maxid
                _maxid += 1

            _items.append((file, _ids[dirname]))

        ImagesDataset.__init__(self, _items, transform)


class CharacterDataset(Dataset):
    def __init__(self, images_dataset: ImagesDataset, group_size: int = 100,
                 prob: float = 0.5, force_prob: bool = True):
        self.images_dataset = images_dataset
        self.groups: Dict[int, List[int]] = {}
        for i, (_, idx) in enumerate(self.images_dataset.items):
            if idx not in self.groups:
                self.groups[idx] = []
            self.groups[idx].append(i)

        self.group_size = group_size
        self._id_map = list(self.groups.keys())
        self._x_to_y, self._y_to_x = get_reg_for_prob(prob)
        self.force_prob = force_prob

    def __len__(self):
        return len(self.groups)

    def __getitem__(self, item):
        idx = self._id_map[item]
        current_samples = self._x_to_y(self.group_size)
        if current_samples > len(self.groups[idx]) and self.force_prob:
            total_samples = self._y_to_x(len(self.groups[idx]))
            current_samples = self._x_to_y(total_samples)
            ex_samples = total_samples - current_samples
        else:
            ex_samples = self.group_size - current_samples

        indices = []
        indices.extend(random.sample(self.groups[idx], k=current_samples))
        for _ in range(ex_samples):
            while True:
                t_idx = random.choice(list(self.groups.keys()))
                if t_idx != idx:
                    break
            indices.append(random.choice(self.groups[t_idx]))

        random.shuffle(indices)
        images, labels = [], []
        for i in indices:
            image, label = self.images_dataset[i]
            images.append(image)
            labels.append(label)

        return torch.stack(list(map(torch.as_tensor, images))), \
               torch.stack(list(map(torch.as_tensor, labels)))
+3 −1
Original line number Diff line number Diff line
@@ -11,7 +11,9 @@ class FocalLoss(nn.Module):

    def __init__(self, weight=None, gamma=2., reduction='mean'):
        nn.Module.__init__(self)
        self.weight = torch.as_tensor(weight).float() if weight is not None else weight
        weight = torch.as_tensor(weight).float() if weight is not None else weight
        self.register_buffer('weight', weight)

        self.gamma = gamma
        self.reduction = reduction

zoo/ccip/prob.py

0 → 100644
+54 −0
Original line number Diff line number Diff line
from typing import Tuple, Callable

import numpy as np
from sklearn.linear_model import LinearRegression


def _array_create(n, m):
    return [0 if i < m else i - m + 1 for i in range(0, n)]


def _possibility(n, m):
    v = _array_create(n, m)
    same, not_same = 0, 0
    for i in range(n):
        for j in range(i, n):
            if v[i] == v[j]:
                same += 1
            else:
                not_same += 1

    return same / (not_same + same)


def _get_m_for_n(n, p=0.6):
    left, right = 0, n
    while left < right:
        middle = (left + right) // 2
        if _possibility(n, middle) < p:
            left = middle + 1
        else:
            right = middle

    return right


def get_reg_for_prob(prob=0.5) -> Tuple[Callable[[int], int], Callable[[int], int]]:
    x_arr = np.asarray(range(2, 150))
    y_arr = np.asarray([_get_m_for_n(i, p=prob) for i in x_arr])

    x_to_y = LinearRegression()
    x_to_y.fit(x_arr.reshape(-1, 1), y_arr)

    def _x_to_y(x: int) -> int:
        raw = x_to_y.predict(np.asarray([[x]]))[0].tolist()
        return int(round(raw))

    y_to_x = LinearRegression()
    y_to_x.fit(y_arr.reshape(-1, 1), x_arr)

    def _y_to_x(y: int) -> int:
        raw = y_to_x.predict(np.asarray([[y]]))[0].tolist()
        return int(round(raw))

    return _x_to_y, _y_to_x

zoo/ccip/train_.py

0 → 100644
+219 −0
Original line number Diff line number Diff line
import os
import re
from typing import Optional

import torch
import torch.nn.functional as F
from accelerate import Accelerator
from ditk import logging
from hbutils.random import global_seed
from torch.optim import lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import Compose
from tqdm.auto import tqdm

from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset
from .loss import FocalLoss
from .model import CCIP
from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR

_TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'ccip')
_LOG_DIR = os.path.join(_TRAIN_DIR, 'logs')
_CKPT_DIR = os.path.join(_TRAIN_DIR, 'ckpts')

_CKPT_PATTERN = re.compile(r'^ccip-(?P<name>[a-zA-Z\d_\-]+)-(?P<epoch>\d+)\.ckpt$')


def _find_latest_ckpt(name: str) -> Optional[str]:
    if os.path.exists(_CKPT_DIR):
        ckpts = []
        for filename in os.listdir(_CKPT_DIR):
            matching = _CKPT_PATTERN.fullmatch(os.path.basename(filename))
            if matching and matching.group('name') == name:
                ckpts.append((int(matching.group('epoch')), os.path.join(_CKPT_DIR, filename)))

        ckpts = sorted(ckpts)
        if ckpts:
            return ckpts[-1][1]
        else:
            return None
    else:
        return None


def _ckpt_epoch(filename: Optional[str]) -> Optional[int]:
    if filename is not None:
        matching = _CKPT_PATTERN.fullmatch(os.path.basename(filename))
        if matching:
            return int(matching.group('epoch'))
        else:
            return None
    else:
        return None


def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None,
          train_ratio: float = 0.8, max_epochs: int = 500, group_size: int = 100,
          learning_rate: float = 0.001, weight_decay: float = 1e-3, preference: float = 0.0,
          save_per_epoch: int = 10, eval_epoch: int = 5,
          model_name: str = 'clip/ViT-B/32', seed: Optional[int] = 0):
    if seed is not None:
        # native random, numpy, torch and faker's seeds are includes
        # if you need to register more library for seeding, see:
        # https://hansbug.github.io/hbutils/main/api_doc/random/state.html#register-random-source
        global_seed(seed)

    accelerator = Accelerator(
        # mixed_precision=self.cfgs.mixed_precision,
        step_scheduler_with_optimizer=False,
    )

    session_name = session_name or re.sub(r'\W+', '-', model_name)
    _log_dir = os.path.join(_LOG_DIR, session_name)

    if accelerator.is_local_main_process:
        os.makedirs(_log_dir, exist_ok=True)
        os.makedirs(_CKPT_DIR, exist_ok=True)
        writer = SummaryWriter(_log_dir)
        writer.add_custom_scalars({
            "general": {
                "accuracy": ["Multiline", ["train/accuracy", "test/accuracy"]],
                "false": ["Multiline", ["test/fn", "test/fp", "train/fn", "train/fp"]],
            },
            "test": {
                "false": ["Multiline", ["test/fn", "test/fp"]],
            },
            "train": {
                "false": ["Multiline", ["train/fn", "train/fp"]],
            },
        })
    else:
        writer = None

    model = CCIP(model_name)
    preprocess = model.preprocess
    train_transform = Compose([*TRAIN_TRANSFORM.transforms, *preprocess.transforms])
    test_transform = preprocess

    image_dataset = CCIPImagesDataset(dataset_dir)
    train_image_dataset, test_image_dataset = image_dataset.split_dataset(test_prob=1 - train_ratio)
    train_image_dataset.transform = train_transform
    test_image_dataset.transform = test_transform

    train_dataset = CharacterDataset(train_image_dataset, group_size)
    test_dataset = CharacterDataset(test_image_dataset, group_size)

    if from_ckpt is None:
        from_ckpt = _find_latest_ckpt(session_name)
    previous_epoch = _ckpt_epoch(from_ckpt) or 0
    if from_ckpt:
        logging.info(f'Load checkpoint from {from_ckpt!r}.')
        model.load_state_dict(torch.load(from_ckpt, map_location='cpu'))
    else:
        logging.info(f'No checkpoint found, new model will be used.')

    if preference < 0:
        loss_weight = torch.as_tensor([torch.e, 1.0]) ** -preference
    else:
        loss_weight = torch.as_tensor([1.0, torch.e]) ** preference
    loss_fn = FocalLoss(weight=loss_weight).to(accelerator.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = lr_scheduler.OneCycleLR(
        optimizer, max_lr=learning_rate,
        steps_per_epoch=len(train_dataset), epochs=max_epochs,
        pct_start=0.15, final_div_factor=20.
    )

    model, optimizer, train_dataset, test_dataset, scheduler = \
        accelerator.prepare(model, optimizer, train_dataset, test_dataset, scheduler)

    for epoch in range(previous_epoch + 1, max_epochs + 1):
        running_loss = 0.0
        train_correct, train_total = 0, 0
        train_fp, train_fn = 0, 0
        model.train()
        for i, (inputs, char_ids) in enumerate(tqdm(train_dataset)):
            inputs = inputs.float()
            inputs = inputs.to(accelerator.device)  # BxCxHxW
            char_ids = char_ids.to(accelerator.device)  # B

            ix = torch.arange(0, char_ids.shape[0])
            mask = ix >= ix.reshape(-1, 1)  # BxB, remove duplicated
            m_labels = (char_ids == char_ids.reshape(-1, 1))  # BxB

            features = model.backbone(inputs)  # BxF
            m_dists = F.pairwise_distance(features, features.unsqueeze(1))  # BxB
            dists = m_dists[mask].to(accelerator.device)
            labels = m_labels[mask].type(torch.long).to(accelerator.device)

            outputs = model.diff(dists.reshape(-1, 1))
            preds = torch.argmax(outputs, dim=1)
            train_correct += (preds == labels).sum().item()
            train_fp += (preds[labels == 0] == 1).sum().item()
            train_fn += (preds[labels == 1] == 0).sum().item()
            train_total += labels.shape[0]

            loss = loss_fn(outputs, labels)
            accelerator.backward(loss)
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            scheduler.step()

        epoch_loss = running_loss / train_total
        train_accuracy = train_correct / train_total
        train_fp_p = train_fp / train_total
        train_fn_p = train_fn / train_total

        if accelerator.is_local_main_process:
            logging.info(f'Epoch [{epoch}/{max_epochs}], loss: {epoch_loss:.6f}, '
                         f'train accuracy: {train_accuracy:.4f}, '
                         f'false positive: {train_fp_p:.4f}, false negative: {train_fn_p:.4f}')
            if writer:
                writer.add_scalar('train/loss', epoch_loss, epoch)
                writer.add_scalar('train/accuracy', train_accuracy, epoch)
                writer.add_scalar('train/fp', train_fp_p, epoch)
                writer.add_scalar('train/fn', train_fn_p, epoch)

        model.eval()
        if epoch % eval_epoch == 0:
            with torch.no_grad():
                test_correct, test_total = 0, 0
                test_fp, test_fn = 0, 0

                for i, (inputs, char_ids) in enumerate(tqdm(test_dataset)):
                    inputs = inputs.float()
                    inputs = inputs.to(accelerator.device)  # BxCxHxW
                    char_ids = char_ids.to(accelerator.device)  # B

                    ix = torch.arange(0, char_ids.shape[0])
                    mask = ix >= ix.reshape(-1, 1)  # BxB, remove duplicated
                    m_labels = (char_ids == char_ids.reshape(-1, 1))  # BxB

                    features = model.backbone(inputs)  # BxF
                    m_dists = F.pairwise_distance(features, features.unsqueeze(1))  # BxB
                    dists = m_dists[mask].to(accelerator.device)
                    labels = m_labels[mask].type(torch.long).to(accelerator.device)

                    outputs = model.diff(dists.reshape(-1, 1))
                    preds = torch.argmax(outputs, dim=1)
                    train_correct += (preds == labels).sum().item()
                    train_fp += (preds[labels == 0] == 1).sum().item()
                    train_fn += (preds[labels == 1] == 0).sum().item()
                    train_total += labels.shape[0]

                test_accuracy = test_correct / test_total
                test_fp_p = test_fp / test_total
                test_fn_p = test_fn / test_total

                if accelerator.is_local_main_process:
                    logging.info(f'Epoch {epoch}, test accuracy: {test_accuracy:.4f}, '
                                 f'false positive: {test_fp_p:.4f}, false negative: {test_fn_p:.4f}')
                    if writer:
                        writer.add_scalar('test/accuracy', test_accuracy, epoch)
                        writer.add_scalar('test/fp', test_fp_p, epoch)
                        writer.add_scalar('test/fn', test_fn_p, epoch)

        if accelerator.is_local_main_process and epoch % save_per_epoch == 0:
            current_ckpt_file = os.path.join(_CKPT_DIR, f'ccip-{session_name}-{epoch}.ckpt')
            torch.save(model.state_dict(), current_ckpt_file)
            logging.info(f'Saved to {current_ckpt_file!r}.')