Loading zoo/ccip/__main__.py 0 → 100644 +16 −0 Original line number Diff line number Diff line from functools import partial import click from ..utils import GLOBAL_CONTEXT_SETTINGS from ..utils import print_version as _origin_print_version print_version = partial(_origin_print_version, 'zoo.ccip') @click.group(context_settings={**GLOBAL_CONTEXT_SETTINGS}) @click.option('-v', '--version', is_flag=True, callback=print_version, expose_value=False, is_eager=True, help="Utils with pixiv resources.") def cli(): pass # pragma: no cover zoo/ccip/dataset.py +93 −13 Original line number Diff line number Diff line import glob import os.path from typing import List, Tuple import random from typing import List, Tuple, Dict import torch from PIL import Image from torch.utils.data import Dataset from torchvision import transforms from imgutils.data import load_image from .prob import get_reg_for_prob TRAIN_TRANSFORM = transforms.Compose([ transforms.Resize(300), transforms.RandomCrop(250, padding=25, pad_if_needed=True, padding_mode='reflect'), transforms.RandomRotation((-15, 15)), transforms.RandomHorizontalFlip(), transforms.ColorJitter(0.10, 0.10, 0.10, 0.10), ]) TEST_TRANSFORM = transforms.Compose([]) class CCIPDataset(Dataset): def __init__(self, root_dir, transform=None): _ids, _maxid = {}, 0 self.items: List[Tuple[str, int]] = [] self.transform = transform for file in glob.glob(os.path.join(root_dir, '*', '*', '*.jpg')): dirname = os.path.normcase(os.path.normpath(os.path.dirname(os.path.abspath(file)))) if dirname not in _ids: _ids[dirname] = _maxid _maxid += 1 self.items.append((file, _ids[dirname])) class ImagesDataset(Dataset): def __init__(self, items: List[Tuple[str, int]], transform=None): self.items: List[Tuple[str, int]] = items self.transform = transform def __len__(self): return len(self.items) Loading @@ -32,3 +36,79 @@ class CCIPDataset(Dataset): image = self.transform(image) return image, idx def split_dataset(self, test_prob: float = 0.2): total = len(self.items) test_ids = set(random.sample(list(range(total)), k=int(total * test_prob))) train_items, test_items = [], [] for i, item in enumerate(self.items): if i in test_ids: test_items.append(item) else: train_items.append(item) return ImagesDataset(train_items, self.transform), ImagesDataset(test_items, self.transform) class CCIPImagesDataset(ImagesDataset): def __init__(self, root_dir, transform=None): _ids, _maxid = {}, 0 _items: List[Tuple[str, int]] = [] for file in glob.glob(os.path.join(root_dir, '*', '*', '*.jpg')): dirname = os.path.normcase(os.path.normpath(os.path.dirname(os.path.abspath(file)))) if dirname not in _ids: _ids[dirname] = _maxid _maxid += 1 _items.append((file, _ids[dirname])) ImagesDataset.__init__(self, _items, transform) class CharacterDataset(Dataset): def __init__(self, images_dataset: ImagesDataset, group_size: int = 100, prob: float = 0.5, force_prob: bool = True): self.images_dataset = images_dataset self.groups: Dict[int, List[int]] = {} for i, (_, idx) in enumerate(self.images_dataset.items): if idx not in self.groups: self.groups[idx] = [] self.groups[idx].append(i) self.group_size = group_size self._id_map = list(self.groups.keys()) self._x_to_y, self._y_to_x = get_reg_for_prob(prob) self.force_prob = force_prob def __len__(self): return len(self.groups) def __getitem__(self, item): idx = self._id_map[item] current_samples = self._x_to_y(self.group_size) if current_samples > len(self.groups[idx]) and self.force_prob: total_samples = self._y_to_x(len(self.groups[idx])) current_samples = self._x_to_y(total_samples) ex_samples = total_samples - current_samples else: ex_samples = self.group_size - current_samples indices = [] indices.extend(random.sample(self.groups[idx], k=current_samples)) for _ in range(ex_samples): while True: t_idx = random.choice(list(self.groups.keys())) if t_idx != idx: break indices.append(random.choice(self.groups[t_idx])) random.shuffle(indices) images, labels = [], [] for i in indices: image, label = self.images_dataset[i] images.append(image) labels.append(label) return torch.stack(list(map(torch.as_tensor, images))), \ torch.stack(list(map(torch.as_tensor, labels))) zoo/ccip/loss.py +3 −1 Original line number Diff line number Diff line Loading @@ -11,7 +11,9 @@ class FocalLoss(nn.Module): def __init__(self, weight=None, gamma=2., reduction='mean'): nn.Module.__init__(self) self.weight = torch.as_tensor(weight).float() if weight is not None else weight weight = torch.as_tensor(weight).float() if weight is not None else weight self.register_buffer('weight', weight) self.gamma = gamma self.reduction = reduction Loading zoo/ccip/prob.py 0 → 100644 +54 −0 Original line number Diff line number Diff line from typing import Tuple, Callable import numpy as np from sklearn.linear_model import LinearRegression def _array_create(n, m): return [0 if i < m else i - m + 1 for i in range(0, n)] def _possibility(n, m): v = _array_create(n, m) same, not_same = 0, 0 for i in range(n): for j in range(i, n): if v[i] == v[j]: same += 1 else: not_same += 1 return same / (not_same + same) def _get_m_for_n(n, p=0.6): left, right = 0, n while left < right: middle = (left + right) // 2 if _possibility(n, middle) < p: left = middle + 1 else: right = middle return right def get_reg_for_prob(prob=0.5) -> Tuple[Callable[[int], int], Callable[[int], int]]: x_arr = np.asarray(range(2, 150)) y_arr = np.asarray([_get_m_for_n(i, p=prob) for i in x_arr]) x_to_y = LinearRegression() x_to_y.fit(x_arr.reshape(-1, 1), y_arr) def _x_to_y(x: int) -> int: raw = x_to_y.predict(np.asarray([[x]]))[0].tolist() return int(round(raw)) y_to_x = LinearRegression() y_to_x.fit(y_arr.reshape(-1, 1), x_arr) def _y_to_x(y: int) -> int: raw = y_to_x.predict(np.asarray([[y]]))[0].tolist() return int(round(raw)) return _x_to_y, _y_to_x zoo/ccip/train_.py 0 → 100644 +219 −0 Original line number Diff line number Diff line import os import re from typing import Optional import torch import torch.nn.functional as F from accelerate import Accelerator from ditk import logging from hbutils.random import global_seed from torch.optim import lr_scheduler from torch.utils.tensorboard import SummaryWriter from torchvision.transforms import Compose from tqdm.auto import tqdm from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset from .loss import FocalLoss from .model import CCIP from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR _TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'ccip') _LOG_DIR = os.path.join(_TRAIN_DIR, 'logs') _CKPT_DIR = os.path.join(_TRAIN_DIR, 'ckpts') _CKPT_PATTERN = re.compile(r'^ccip-(?P<name>[a-zA-Z\d_\-]+)-(?P<epoch>\d+)\.ckpt$') def _find_latest_ckpt(name: str) -> Optional[str]: if os.path.exists(_CKPT_DIR): ckpts = [] for filename in os.listdir(_CKPT_DIR): matching = _CKPT_PATTERN.fullmatch(os.path.basename(filename)) if matching and matching.group('name') == name: ckpts.append((int(matching.group('epoch')), os.path.join(_CKPT_DIR, filename))) ckpts = sorted(ckpts) if ckpts: return ckpts[-1][1] else: return None else: return None def _ckpt_epoch(filename: Optional[str]) -> Optional[int]: if filename is not None: matching = _CKPT_PATTERN.fullmatch(os.path.basename(filename)) if matching: return int(matching.group('epoch')) else: return None else: return None def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None, train_ratio: float = 0.8, max_epochs: int = 500, group_size: int = 100, learning_rate: float = 0.001, weight_decay: float = 1e-3, preference: float = 0.0, save_per_epoch: int = 10, eval_epoch: int = 5, model_name: str = 'clip/ViT-B/32', seed: Optional[int] = 0): if seed is not None: # native random, numpy, torch and faker's seeds are includes # if you need to register more library for seeding, see: # https://hansbug.github.io/hbutils/main/api_doc/random/state.html#register-random-source global_seed(seed) accelerator = Accelerator( # mixed_precision=self.cfgs.mixed_precision, step_scheduler_with_optimizer=False, ) session_name = session_name or re.sub(r'\W+', '-', model_name) _log_dir = os.path.join(_LOG_DIR, session_name) if accelerator.is_local_main_process: os.makedirs(_log_dir, exist_ok=True) os.makedirs(_CKPT_DIR, exist_ok=True) writer = SummaryWriter(_log_dir) writer.add_custom_scalars({ "general": { "accuracy": ["Multiline", ["train/accuracy", "test/accuracy"]], "false": ["Multiline", ["test/fn", "test/fp", "train/fn", "train/fp"]], }, "test": { "false": ["Multiline", ["test/fn", "test/fp"]], }, "train": { "false": ["Multiline", ["train/fn", "train/fp"]], }, }) else: writer = None model = CCIP(model_name) preprocess = model.preprocess train_transform = Compose([*TRAIN_TRANSFORM.transforms, *preprocess.transforms]) test_transform = preprocess image_dataset = CCIPImagesDataset(dataset_dir) train_image_dataset, test_image_dataset = image_dataset.split_dataset(test_prob=1 - train_ratio) train_image_dataset.transform = train_transform test_image_dataset.transform = test_transform train_dataset = CharacterDataset(train_image_dataset, group_size) test_dataset = CharacterDataset(test_image_dataset, group_size) if from_ckpt is None: from_ckpt = _find_latest_ckpt(session_name) previous_epoch = _ckpt_epoch(from_ckpt) or 0 if from_ckpt: logging.info(f'Load checkpoint from {from_ckpt!r}.') model.load_state_dict(torch.load(from_ckpt, map_location='cpu')) else: logging.info(f'No checkpoint found, new model will be used.') if preference < 0: loss_weight = torch.as_tensor([torch.e, 1.0]) ** -preference else: loss_weight = torch.as_tensor([1.0, torch.e]) ** preference loss_fn = FocalLoss(weight=loss_weight).to(accelerator.device) optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) scheduler = lr_scheduler.OneCycleLR( optimizer, max_lr=learning_rate, steps_per_epoch=len(train_dataset), epochs=max_epochs, pct_start=0.15, final_div_factor=20. ) model, optimizer, train_dataset, test_dataset, scheduler = \ accelerator.prepare(model, optimizer, train_dataset, test_dataset, scheduler) for epoch in range(previous_epoch + 1, max_epochs + 1): running_loss = 0.0 train_correct, train_total = 0, 0 train_fp, train_fn = 0, 0 model.train() for i, (inputs, char_ids) in enumerate(tqdm(train_dataset)): inputs = inputs.float() inputs = inputs.to(accelerator.device) # BxCxHxW char_ids = char_ids.to(accelerator.device) # B ix = torch.arange(0, char_ids.shape[0]) mask = ix >= ix.reshape(-1, 1) # BxB, remove duplicated m_labels = (char_ids == char_ids.reshape(-1, 1)) # BxB features = model.backbone(inputs) # BxF m_dists = F.pairwise_distance(features, features.unsqueeze(1)) # BxB dists = m_dists[mask].to(accelerator.device) labels = m_labels[mask].type(torch.long).to(accelerator.device) outputs = model.diff(dists.reshape(-1, 1)) preds = torch.argmax(outputs, dim=1) train_correct += (preds == labels).sum().item() train_fp += (preds[labels == 0] == 1).sum().item() train_fn += (preds[labels == 1] == 0).sum().item() train_total += labels.shape[0] loss = loss_fn(outputs, labels) accelerator.backward(loss) optimizer.step() running_loss += loss.item() * inputs.size(0) scheduler.step() epoch_loss = running_loss / train_total train_accuracy = train_correct / train_total train_fp_p = train_fp / train_total train_fn_p = train_fn / train_total if accelerator.is_local_main_process: logging.info(f'Epoch [{epoch}/{max_epochs}], loss: {epoch_loss:.6f}, ' f'train accuracy: {train_accuracy:.4f}, ' f'false positive: {train_fp_p:.4f}, false negative: {train_fn_p:.4f}') if writer: writer.add_scalar('train/loss', epoch_loss, epoch) writer.add_scalar('train/accuracy', train_accuracy, epoch) writer.add_scalar('train/fp', train_fp_p, epoch) writer.add_scalar('train/fn', train_fn_p, epoch) model.eval() if epoch % eval_epoch == 0: with torch.no_grad(): test_correct, test_total = 0, 0 test_fp, test_fn = 0, 0 for i, (inputs, char_ids) in enumerate(tqdm(test_dataset)): inputs = inputs.float() inputs = inputs.to(accelerator.device) # BxCxHxW char_ids = char_ids.to(accelerator.device) # B ix = torch.arange(0, char_ids.shape[0]) mask = ix >= ix.reshape(-1, 1) # BxB, remove duplicated m_labels = (char_ids == char_ids.reshape(-1, 1)) # BxB features = model.backbone(inputs) # BxF m_dists = F.pairwise_distance(features, features.unsqueeze(1)) # BxB dists = m_dists[mask].to(accelerator.device) labels = m_labels[mask].type(torch.long).to(accelerator.device) outputs = model.diff(dists.reshape(-1, 1)) preds = torch.argmax(outputs, dim=1) train_correct += (preds == labels).sum().item() train_fp += (preds[labels == 0] == 1).sum().item() train_fn += (preds[labels == 1] == 0).sum().item() train_total += labels.shape[0] test_accuracy = test_correct / test_total test_fp_p = test_fp / test_total test_fn_p = test_fn / test_total if accelerator.is_local_main_process: logging.info(f'Epoch {epoch}, test accuracy: {test_accuracy:.4f}, ' f'false positive: {test_fp_p:.4f}, false negative: {test_fn_p:.4f}') if writer: writer.add_scalar('test/accuracy', test_accuracy, epoch) writer.add_scalar('test/fp', test_fp_p, epoch) writer.add_scalar('test/fn', test_fn_p, epoch) if accelerator.is_local_main_process and epoch % save_per_epoch == 0: current_ckpt_file = os.path.join(_CKPT_DIR, f'ccip-{session_name}-{epoch}.ckpt') torch.save(model.state_dict(), current_ckpt_file) logging.info(f'Saved to {current_ckpt_file!r}.') Loading
zoo/ccip/__main__.py 0 → 100644 +16 −0 Original line number Diff line number Diff line from functools import partial import click from ..utils import GLOBAL_CONTEXT_SETTINGS from ..utils import print_version as _origin_print_version print_version = partial(_origin_print_version, 'zoo.ccip') @click.group(context_settings={**GLOBAL_CONTEXT_SETTINGS}) @click.option('-v', '--version', is_flag=True, callback=print_version, expose_value=False, is_eager=True, help="Utils with pixiv resources.") def cli(): pass # pragma: no cover
zoo/ccip/dataset.py +93 −13 Original line number Diff line number Diff line import glob import os.path from typing import List, Tuple import random from typing import List, Tuple, Dict import torch from PIL import Image from torch.utils.data import Dataset from torchvision import transforms from imgutils.data import load_image from .prob import get_reg_for_prob TRAIN_TRANSFORM = transforms.Compose([ transforms.Resize(300), transforms.RandomCrop(250, padding=25, pad_if_needed=True, padding_mode='reflect'), transforms.RandomRotation((-15, 15)), transforms.RandomHorizontalFlip(), transforms.ColorJitter(0.10, 0.10, 0.10, 0.10), ]) TEST_TRANSFORM = transforms.Compose([]) class CCIPDataset(Dataset): def __init__(self, root_dir, transform=None): _ids, _maxid = {}, 0 self.items: List[Tuple[str, int]] = [] self.transform = transform for file in glob.glob(os.path.join(root_dir, '*', '*', '*.jpg')): dirname = os.path.normcase(os.path.normpath(os.path.dirname(os.path.abspath(file)))) if dirname not in _ids: _ids[dirname] = _maxid _maxid += 1 self.items.append((file, _ids[dirname])) class ImagesDataset(Dataset): def __init__(self, items: List[Tuple[str, int]], transform=None): self.items: List[Tuple[str, int]] = items self.transform = transform def __len__(self): return len(self.items) Loading @@ -32,3 +36,79 @@ class CCIPDataset(Dataset): image = self.transform(image) return image, idx def split_dataset(self, test_prob: float = 0.2): total = len(self.items) test_ids = set(random.sample(list(range(total)), k=int(total * test_prob))) train_items, test_items = [], [] for i, item in enumerate(self.items): if i in test_ids: test_items.append(item) else: train_items.append(item) return ImagesDataset(train_items, self.transform), ImagesDataset(test_items, self.transform) class CCIPImagesDataset(ImagesDataset): def __init__(self, root_dir, transform=None): _ids, _maxid = {}, 0 _items: List[Tuple[str, int]] = [] for file in glob.glob(os.path.join(root_dir, '*', '*', '*.jpg')): dirname = os.path.normcase(os.path.normpath(os.path.dirname(os.path.abspath(file)))) if dirname not in _ids: _ids[dirname] = _maxid _maxid += 1 _items.append((file, _ids[dirname])) ImagesDataset.__init__(self, _items, transform) class CharacterDataset(Dataset): def __init__(self, images_dataset: ImagesDataset, group_size: int = 100, prob: float = 0.5, force_prob: bool = True): self.images_dataset = images_dataset self.groups: Dict[int, List[int]] = {} for i, (_, idx) in enumerate(self.images_dataset.items): if idx not in self.groups: self.groups[idx] = [] self.groups[idx].append(i) self.group_size = group_size self._id_map = list(self.groups.keys()) self._x_to_y, self._y_to_x = get_reg_for_prob(prob) self.force_prob = force_prob def __len__(self): return len(self.groups) def __getitem__(self, item): idx = self._id_map[item] current_samples = self._x_to_y(self.group_size) if current_samples > len(self.groups[idx]) and self.force_prob: total_samples = self._y_to_x(len(self.groups[idx])) current_samples = self._x_to_y(total_samples) ex_samples = total_samples - current_samples else: ex_samples = self.group_size - current_samples indices = [] indices.extend(random.sample(self.groups[idx], k=current_samples)) for _ in range(ex_samples): while True: t_idx = random.choice(list(self.groups.keys())) if t_idx != idx: break indices.append(random.choice(self.groups[t_idx])) random.shuffle(indices) images, labels = [], [] for i in indices: image, label = self.images_dataset[i] images.append(image) labels.append(label) return torch.stack(list(map(torch.as_tensor, images))), \ torch.stack(list(map(torch.as_tensor, labels)))
zoo/ccip/loss.py +3 −1 Original line number Diff line number Diff line Loading @@ -11,7 +11,9 @@ class FocalLoss(nn.Module): def __init__(self, weight=None, gamma=2., reduction='mean'): nn.Module.__init__(self) self.weight = torch.as_tensor(weight).float() if weight is not None else weight weight = torch.as_tensor(weight).float() if weight is not None else weight self.register_buffer('weight', weight) self.gamma = gamma self.reduction = reduction Loading
zoo/ccip/prob.py 0 → 100644 +54 −0 Original line number Diff line number Diff line from typing import Tuple, Callable import numpy as np from sklearn.linear_model import LinearRegression def _array_create(n, m): return [0 if i < m else i - m + 1 for i in range(0, n)] def _possibility(n, m): v = _array_create(n, m) same, not_same = 0, 0 for i in range(n): for j in range(i, n): if v[i] == v[j]: same += 1 else: not_same += 1 return same / (not_same + same) def _get_m_for_n(n, p=0.6): left, right = 0, n while left < right: middle = (left + right) // 2 if _possibility(n, middle) < p: left = middle + 1 else: right = middle return right def get_reg_for_prob(prob=0.5) -> Tuple[Callable[[int], int], Callable[[int], int]]: x_arr = np.asarray(range(2, 150)) y_arr = np.asarray([_get_m_for_n(i, p=prob) for i in x_arr]) x_to_y = LinearRegression() x_to_y.fit(x_arr.reshape(-1, 1), y_arr) def _x_to_y(x: int) -> int: raw = x_to_y.predict(np.asarray([[x]]))[0].tolist() return int(round(raw)) y_to_x = LinearRegression() y_to_x.fit(y_arr.reshape(-1, 1), x_arr) def _y_to_x(y: int) -> int: raw = y_to_x.predict(np.asarray([[y]]))[0].tolist() return int(round(raw)) return _x_to_y, _y_to_x
zoo/ccip/train_.py 0 → 100644 +219 −0 Original line number Diff line number Diff line import os import re from typing import Optional import torch import torch.nn.functional as F from accelerate import Accelerator from ditk import logging from hbutils.random import global_seed from torch.optim import lr_scheduler from torch.utils.tensorboard import SummaryWriter from torchvision.transforms import Compose from tqdm.auto import tqdm from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset from .loss import FocalLoss from .model import CCIP from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR _TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'ccip') _LOG_DIR = os.path.join(_TRAIN_DIR, 'logs') _CKPT_DIR = os.path.join(_TRAIN_DIR, 'ckpts') _CKPT_PATTERN = re.compile(r'^ccip-(?P<name>[a-zA-Z\d_\-]+)-(?P<epoch>\d+)\.ckpt$') def _find_latest_ckpt(name: str) -> Optional[str]: if os.path.exists(_CKPT_DIR): ckpts = [] for filename in os.listdir(_CKPT_DIR): matching = _CKPT_PATTERN.fullmatch(os.path.basename(filename)) if matching and matching.group('name') == name: ckpts.append((int(matching.group('epoch')), os.path.join(_CKPT_DIR, filename))) ckpts = sorted(ckpts) if ckpts: return ckpts[-1][1] else: return None else: return None def _ckpt_epoch(filename: Optional[str]) -> Optional[int]: if filename is not None: matching = _CKPT_PATTERN.fullmatch(os.path.basename(filename)) if matching: return int(matching.group('epoch')) else: return None else: return None def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None, train_ratio: float = 0.8, max_epochs: int = 500, group_size: int = 100, learning_rate: float = 0.001, weight_decay: float = 1e-3, preference: float = 0.0, save_per_epoch: int = 10, eval_epoch: int = 5, model_name: str = 'clip/ViT-B/32', seed: Optional[int] = 0): if seed is not None: # native random, numpy, torch and faker's seeds are includes # if you need to register more library for seeding, see: # https://hansbug.github.io/hbutils/main/api_doc/random/state.html#register-random-source global_seed(seed) accelerator = Accelerator( # mixed_precision=self.cfgs.mixed_precision, step_scheduler_with_optimizer=False, ) session_name = session_name or re.sub(r'\W+', '-', model_name) _log_dir = os.path.join(_LOG_DIR, session_name) if accelerator.is_local_main_process: os.makedirs(_log_dir, exist_ok=True) os.makedirs(_CKPT_DIR, exist_ok=True) writer = SummaryWriter(_log_dir) writer.add_custom_scalars({ "general": { "accuracy": ["Multiline", ["train/accuracy", "test/accuracy"]], "false": ["Multiline", ["test/fn", "test/fp", "train/fn", "train/fp"]], }, "test": { "false": ["Multiline", ["test/fn", "test/fp"]], }, "train": { "false": ["Multiline", ["train/fn", "train/fp"]], }, }) else: writer = None model = CCIP(model_name) preprocess = model.preprocess train_transform = Compose([*TRAIN_TRANSFORM.transforms, *preprocess.transforms]) test_transform = preprocess image_dataset = CCIPImagesDataset(dataset_dir) train_image_dataset, test_image_dataset = image_dataset.split_dataset(test_prob=1 - train_ratio) train_image_dataset.transform = train_transform test_image_dataset.transform = test_transform train_dataset = CharacterDataset(train_image_dataset, group_size) test_dataset = CharacterDataset(test_image_dataset, group_size) if from_ckpt is None: from_ckpt = _find_latest_ckpt(session_name) previous_epoch = _ckpt_epoch(from_ckpt) or 0 if from_ckpt: logging.info(f'Load checkpoint from {from_ckpt!r}.') model.load_state_dict(torch.load(from_ckpt, map_location='cpu')) else: logging.info(f'No checkpoint found, new model will be used.') if preference < 0: loss_weight = torch.as_tensor([torch.e, 1.0]) ** -preference else: loss_weight = torch.as_tensor([1.0, torch.e]) ** preference loss_fn = FocalLoss(weight=loss_weight).to(accelerator.device) optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) scheduler = lr_scheduler.OneCycleLR( optimizer, max_lr=learning_rate, steps_per_epoch=len(train_dataset), epochs=max_epochs, pct_start=0.15, final_div_factor=20. ) model, optimizer, train_dataset, test_dataset, scheduler = \ accelerator.prepare(model, optimizer, train_dataset, test_dataset, scheduler) for epoch in range(previous_epoch + 1, max_epochs + 1): running_loss = 0.0 train_correct, train_total = 0, 0 train_fp, train_fn = 0, 0 model.train() for i, (inputs, char_ids) in enumerate(tqdm(train_dataset)): inputs = inputs.float() inputs = inputs.to(accelerator.device) # BxCxHxW char_ids = char_ids.to(accelerator.device) # B ix = torch.arange(0, char_ids.shape[0]) mask = ix >= ix.reshape(-1, 1) # BxB, remove duplicated m_labels = (char_ids == char_ids.reshape(-1, 1)) # BxB features = model.backbone(inputs) # BxF m_dists = F.pairwise_distance(features, features.unsqueeze(1)) # BxB dists = m_dists[mask].to(accelerator.device) labels = m_labels[mask].type(torch.long).to(accelerator.device) outputs = model.diff(dists.reshape(-1, 1)) preds = torch.argmax(outputs, dim=1) train_correct += (preds == labels).sum().item() train_fp += (preds[labels == 0] == 1).sum().item() train_fn += (preds[labels == 1] == 0).sum().item() train_total += labels.shape[0] loss = loss_fn(outputs, labels) accelerator.backward(loss) optimizer.step() running_loss += loss.item() * inputs.size(0) scheduler.step() epoch_loss = running_loss / train_total train_accuracy = train_correct / train_total train_fp_p = train_fp / train_total train_fn_p = train_fn / train_total if accelerator.is_local_main_process: logging.info(f'Epoch [{epoch}/{max_epochs}], loss: {epoch_loss:.6f}, ' f'train accuracy: {train_accuracy:.4f}, ' f'false positive: {train_fp_p:.4f}, false negative: {train_fn_p:.4f}') if writer: writer.add_scalar('train/loss', epoch_loss, epoch) writer.add_scalar('train/accuracy', train_accuracy, epoch) writer.add_scalar('train/fp', train_fp_p, epoch) writer.add_scalar('train/fn', train_fn_p, epoch) model.eval() if epoch % eval_epoch == 0: with torch.no_grad(): test_correct, test_total = 0, 0 test_fp, test_fn = 0, 0 for i, (inputs, char_ids) in enumerate(tqdm(test_dataset)): inputs = inputs.float() inputs = inputs.to(accelerator.device) # BxCxHxW char_ids = char_ids.to(accelerator.device) # B ix = torch.arange(0, char_ids.shape[0]) mask = ix >= ix.reshape(-1, 1) # BxB, remove duplicated m_labels = (char_ids == char_ids.reshape(-1, 1)) # BxB features = model.backbone(inputs) # BxF m_dists = F.pairwise_distance(features, features.unsqueeze(1)) # BxB dists = m_dists[mask].to(accelerator.device) labels = m_labels[mask].type(torch.long).to(accelerator.device) outputs = model.diff(dists.reshape(-1, 1)) preds = torch.argmax(outputs, dim=1) train_correct += (preds == labels).sum().item() train_fp += (preds[labels == 0] == 1).sum().item() train_fn += (preds[labels == 1] == 0).sum().item() train_total += labels.shape[0] test_accuracy = test_correct / test_total test_fp_p = test_fp / test_total test_fn_p = test_fn / test_total if accelerator.is_local_main_process: logging.info(f'Epoch {epoch}, test accuracy: {test_accuracy:.4f}, ' f'false positive: {test_fp_p:.4f}, false negative: {test_fn_p:.4f}') if writer: writer.add_scalar('test/accuracy', test_accuracy, epoch) writer.add_scalar('test/fp', test_fp_p, epoch) writer.add_scalar('test/fn', test_fn_p, epoch) if accelerator.is_local_main_process and epoch % save_per_epoch == 0: current_ckpt_file = os.path.join(_CKPT_DIR, f'ccip-{session_name}-{epoch}.ckpt') torch.save(model.state_dict(), current_ckpt_file) logging.info(f'Saved to {current_ckpt_file!r}.')