Loading zoo/ccip/caformer.py +3 −7 Original line number Diff line number Diff line Loading @@ -26,13 +26,9 @@ class CaformerBackbone(torch.nn.Module): return x def get_caformer(input_resolution: int = 224, heads: int = 32, feat_dims: int = 1024, **kwargs): transform = Compose([ Resize(input_resolution, interpolation=InterpolationMode.BICUBIC), CenterCrop(input_resolution), lambda x: x.convert('RGB'), ToTensor(), def get_caformer(input_resolution: int = 384, heads: int = 32, feat_dims: int = 1024, **kwargs): transform = [ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), ]) ] return CaformerBackbone(input_resolution, heads, feat_dims, **kwargs), transform zoo/ccip/dataset.py +59 −7 Original line number Diff line number Diff line Loading @@ -3,6 +3,7 @@ import os.path import random from typing import List, Tuple, Dict import numpy as np import torch from PIL import Image from torch.utils.data import Dataset Loading @@ -11,14 +12,19 @@ from torchvision import transforms from imgutils.data import load_image from .prob import get_reg_for_prob TRAIN_TRANSFORM = transforms.Compose([ transforms.Resize(300), transforms.RandomCrop(250, padding=25, pad_if_needed=True, padding_mode='reflect'), TRAIN_TRANSFORM = [ transforms.Resize(416), transforms.RandomRotation((-15, 15)), transforms.RandomCrop(384), transforms.RandomHorizontalFlip(), transforms.ColorJitter(0.10, 0.10, 0.10, 0.10), ]) TEST_TRANSFORM = transforms.Compose([]) transforms.ToTensor(), ] TEST_TRANSFORM = [ transforms.Resize(416), transforms.CenterCrop(384), transforms.ToTensor(), ] class ImagesDataset(Dataset): Loading @@ -37,7 +43,7 @@ class ImagesDataset(Dataset): return image, idx def split_dataset(self, test_prob: float = 0.2): def split_dataset(self, test_prob: float = 0.2, train_transform=None, test_transform=None): total = len(self.items) test_ids = set(random.sample(list(range(total)), k=int(total * test_prob))) Loading @@ -48,7 +54,7 @@ class ImagesDataset(Dataset): else: train_items.append(item) return ImagesDataset(train_items, self.transform), ImagesDataset(test_items, self.transform) return ImagesDataset(train_items, train_transform or self.transform), ImagesDataset(test_items, test_transform or self.transform) class CCIPImagesDataset(ImagesDataset): Loading Loading @@ -116,3 +122,49 @@ class CharacterDataset(Dataset): return torch.stack(list(map(torch.as_tensor, images))), \ torch.stack(list(map(torch.as_tensor, labels))) class FastCharacterDataset(Dataset): def __init__(self, images_dataset: ImagesDataset, group_size: int = 100, prob: float = 0.5, **kwargs): self.images_dataset = images_dataset groups: Dict[int, List[int]] = {} for i, (_, cid) in enumerate(self.images_dataset.items): if cid not in groups: groups[cid] = [] groups[cid].append(i) self.groups = {k:np.array(v) for k,v in groups.items()} self.group_size = group_size self.prob = prob*2 def reset(self): idxs = np.arange(0, len(self.images_dataset.items)) np.random.shuffle(idxs) self.idxs = idxs[:-(len(idxs)%self.group_size)] def __len__(self): return len(self.idxs) def __getitem__(self, item): image, cid = self.images_dataset[self.idxs[item]] n_same = int(self.prob) + int((self.prob-int(self.prob))<=(item%self.group_size/self.group_size)) image = [image] cid = [cid] if n_same>0: same_idxs = random.sample(self.groups[cid], k=n_same) for idx in same_idxs: img_i, cid_i = self.images_dataset[idx] image.append(img_i) cid.append(cid_i) return image, cid def char_collect_fn(batch): img_list, cid_list = [], [] for data in batch: img_list.extend(data[0]) cid_list.extend(data[1]) return torch.stack(img_list), torch.tensor(cid_list) No newline at end of file zoo/ccip/loss.py +22 −0 Original line number Diff line number Diff line Loading @@ -52,3 +52,25 @@ class NTXentLoss(nn.Module): pos_tensor = torch.stack(pos_items) return (pos_tensor.sum() + self.eps) / (pos_tensor.shape[0] + self.eps) class MLCELoss(nn.Module): def __init__(self, weight=None, reduction='mean', eps=1e-4): super().__init__() weight = torch.as_tensor(weight).float() if weight is not None else weight self.register_buffer('weight', weight) self.reduction = reduction self.eps = eps def forward(self, input_tensor, target_tensor): log_prob_raw = F.softmax(input_tensor, dim=1) same_mask = (target_tensor.unsqueeze(0) == target_tensor.unsqueeze(1)).long() # [B,B] same_mask_diag0 = same_mask - torch.diag_embed(torch.diag(same_mask)) # diag=0 log_prob_x = log_prob_raw.clone() log_prob_x[same_mask_diag0.bool()] = self.eps log_prob_x.diagonal().copy_((log_prob_raw*same_mask_diag0).sum(dim=1)) log_prob_x = log_prob_x + torch.diag_embed(torch.ones(len(target_tensor))*self.eps) y = torch.arange(0, len(target_tensor)) return F.nll_loss(log_prob_x.log(), y, weight=self.weight, reduction=self.reduction) No newline at end of file zoo/ccip/model.py +34 −19 Original line number Diff line number Diff line Loading @@ -2,19 +2,29 @@ import torch.nn import torch.nn.functional as F from PIL import Image from torch import nn import numpy as np from zoo.utils import get_testfile #from zoo.utils import get_testfile from .backbone import get_backbone class CCIPBatchMetrics(nn.Module): def __init__(self): nn.Module.__init__(self) self.sim = nn.CosineSimilarity(dim=-1) self.logit_scale = nn.Parameter(torch.ones([])*np.log(1/0.07)) #self.sim = nn.CosineSimilarity(dim=-1) def forward(self, x): # x: BxN x = self.sim(x, x.unsqueeze(1)) # BxB return x def forward(self, image_features): # x: BxN # normalized features image_features = image_features/image_features.norm(dim=1, keepdim=True) # cosine similarity as logits logit_scale = self.logit_scale.exp() logits_per_image = logit_scale*image_features@image_features.t() logits_per_image = logits_per_image - torch.diag_embed(torch.diag(logits_per_image)) return logits_per_image class CCIPFeature(torch.nn.Module): Loading Loading @@ -46,17 +56,22 @@ class CCIP(torch.nn.Module): if __name__ == '__main__': image_files = [ get_testfile('6124220.jpg'), get_testfile('6125785.jpg'), get_testfile('6125901.jpg'), ] model = CCIP() data = torch.stack([ model.preprocess(Image.open(img)) for img in image_files ]) print(data.dtype, data.shape) print(F.softmax(model.forward(data), dim=-1)) # image_files = [ # get_testfile('6124220.jpg'), # get_testfile('6125785.jpg'), # get_testfile('6125901.jpg'), # ] # # model = CCIP() # data = torch.stack([ # model.preprocess(Image.open(img)) # for img in image_files # ]) # print(data.dtype, data.shape) # # print(F.softmax(model.forward(data), dim=-1)) data = torch.randn(4,3,384,384).cuda() model = CCIP('caformer').cuda() print(model(data)) zoo/ccip/train_.py +52 −32 Original line number Diff line number Diff line Loading @@ -12,10 +12,11 @@ from sklearn.metrics import accuracy_score from torch.optim import lr_scheduler from torch.utils.tensorboard import SummaryWriter from torchvision.transforms import Compose from torch.utils.data import DataLoader from tqdm.auto import tqdm from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset, TEST_TRANSFORM from .loss import NTXentLoss from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset, TEST_TRANSFORM, char_collect_fn from .loss import NTXentLoss, MLCELoss from .model import CCIP from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR Loading Loading @@ -81,8 +82,8 @@ def _sample_analysis(poss, negs, svm_samples: int = 10000): def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None, train_ratio: float = 0.8, max_epochs: int = 500, group_size: int = 30, learning_rate: float = 0.001, weight_decay: float = 1e-3, tau: float = 0.15, save_per_epoch: int = 10, eval_epoch: int = 5, learning_rate: float = 0.001, weight_decay: float = 1e-2, tau: float = 0.15, save_per_epoch: int = 10, eval_epoch: int = 5, num_workers=8, model_name: str = 'clip/ViT-B/32', seed: Optional[int] = 0): if seed is not None: # native random, numpy, torch and faker's seeds are includes Loading Loading @@ -113,12 +114,15 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio model = CCIP(model_name) image_dataset = CCIPImagesDataset(dataset_dir) train_image_dataset, test_image_dataset = image_dataset.split_dataset(test_prob=1 - train_ratio) train_image_dataset.transform = Compose([*TRAIN_TRANSFORM.transforms, *model.preprocess.transforms]) test_image_dataset.transform = Compose([*TEST_TRANSFORM.transforms, *model.preprocess.transforms]) train_image_dataset, test_image_dataset = image_dataset.split_dataset(test_prob=1 - train_ratio, train_transform=Compose(TRAIN_TRANSFORM.transforms+model.preprocess.transforms), test_transform=Compose(TEST_TRANSFORM.transforms+model.preprocess.transforms),) train_dataset = CharacterDataset(train_image_dataset, group_size, force_prob=False) test_dataset = CharacterDataset(test_image_dataset, group_size) train_dataloader = DataLoader(train_dataset, batch_size=group_size, shuffle=True, num_workers=num_workers, collate_fn=char_collect_fn, drop_last=True) test_dataloader = DataLoader(test_dataset, batch_size=group_size, num_workers=num_workers, collate_fn=char_collect_fn) if from_ckpt is None: from_ckpt = _find_latest_ckpt(session_name) Loading @@ -129,7 +133,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio else: logging.info(f'No checkpoint found, new model will be used.') loss_fn = NTXentLoss(tau=tau).to(accelerator.device) loss_fn = MLCELoss().to(accelerator.device) optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) scheduler = lr_scheduler.OneCycleLR( optimizer, max_lr=learning_rate, Loading @@ -137,36 +141,52 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio pct_start=0.15, final_div_factor=20. ) model, optimizer, train_dataset, test_dataset, scheduler = \ accelerator.prepare(model, optimizer, train_dataset, test_dataset, scheduler) model, optimizer, train_dataloader, test_dataloader, scheduler = \ accelerator.prepare(model, optimizer, train_dataloader, test_dataloader, scheduler) test_dataloader.dataset.reset() for epoch in range(previous_epoch + 1, max_epochs + 1): running_loss = 0.0 train_pos_total = 0 positive_sims, negative_sims = [], [] model.train() for i, (inputs, char_ids) in enumerate(tqdm(train_dataset)): inputs = inputs.float() for i, (inputs, char_ids) in enumerate(tqdm(train_dataloader)): train_dataloader.dataset.reset() inputs = inputs.to(accelerator.device) # BxCxHxW char_ids = char_ids.to(accelerator.device) # B ix = torch.arange(0, char_ids.shape[0]) mask = ix > ix.reshape(-1, 1) # BxB, remove duplicated similarities = model(inputs) # BxB outputs = similarities[mask] # N labels = (char_ids == char_ids.reshape(-1, 1))[mask].to(accelerator.device) # N # B = len(char_ids) # mask = torch.triu(torch.ones(B,B),diagonal=1).to(accelerator.device) # BxB, remove duplicated # similarities = model(inputs) # BxB # outputs = similarities[mask] # N # labels = (char_ids.view(-1,1) == char_ids.view(1,-1))[mask] # N # labels = char_ids outputs = model(inputs) # BxB labels = char_ids loss = loss_fn(outputs, labels) accelerator.backward(loss) optimizer.step() train_pos_total += labels.sum() running_loss += loss.item() * labels.sum() scheduler.step() positive_sims.append(outputs[labels]) negative_sims.append(outputs[~labels]) epoch_loss = running_loss / train_pos_total running_loss += loss.item()*len(char_ids) gt = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu() outputs = outputs.detach().cpu() gt_diag0 = gt.clone() gt_diag0.diagonal().copy_(torch.zeros(len(char_ids))) # outputs.diagonal().copy_(torch.ones(len(char_ids))*-10000) # max_idxs = outputs.argsort(dim=-1) # for max_idx, n_pos in zip(max_idxs, gt.sum(dim=1)): # train_pos_total += n_pos # positive_sims.append(outputs[labels]) # negative_sims.append(outputs[~labels]) train_pos_total += gt_diag0.sum() positive_sims.append(outputs[gt_diag0]) negative_sims.append(outputs[~gt]) epoch_loss = running_loss #/ train_pos_total train_psims = torch.cat(positive_sims) train_nsims = torch.cat(negative_sims) train_pos_mean, train_pos_std, train_neg_mean, train_neg_std, train_threshold, train_acc_svm = \ Loading @@ -188,19 +208,19 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio if epoch % eval_epoch == 0: with torch.no_grad(): positive_sims, negative_sims = [], [] for i, (inputs, char_ids) in enumerate(tqdm(test_dataset)): inputs = inputs.float() for i, (inputs, char_ids) in enumerate(tqdm(test_dataloader)): inputs = inputs.to(accelerator.device) # BxCxHxW char_ids = char_ids.to(accelerator.device) # B ix = torch.arange(0, char_ids.shape[0]) mask = ix > ix.reshape(-1, 1) # BxB, remove duplicated similarities = model(inputs) # BxB outputs = similarities[mask] # N labels = (char_ids == char_ids.reshape(-1, 1))[mask].to(accelerator.device) # N outputs = model(inputs) # BxB positive_sims.append(outputs[labels]) negative_sims.append(outputs[~labels]) gt = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu() outputs = outputs.detach().cpu() gt_diag0 = gt.clone() gt_diag0.diagonal().copy_(torch.zeros(len(char_ids))) train_pos_total += gt_diag0.sum() positive_sims.append(outputs[gt_diag0]) negative_sims.append(outputs[~gt]) test_psims = torch.cat(positive_sims) test_nsims = torch.cat(negative_sims) Loading Loading
zoo/ccip/caformer.py +3 −7 Original line number Diff line number Diff line Loading @@ -26,13 +26,9 @@ class CaformerBackbone(torch.nn.Module): return x def get_caformer(input_resolution: int = 224, heads: int = 32, feat_dims: int = 1024, **kwargs): transform = Compose([ Resize(input_resolution, interpolation=InterpolationMode.BICUBIC), CenterCrop(input_resolution), lambda x: x.convert('RGB'), ToTensor(), def get_caformer(input_resolution: int = 384, heads: int = 32, feat_dims: int = 1024, **kwargs): transform = [ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), ]) ] return CaformerBackbone(input_resolution, heads, feat_dims, **kwargs), transform
zoo/ccip/dataset.py +59 −7 Original line number Diff line number Diff line Loading @@ -3,6 +3,7 @@ import os.path import random from typing import List, Tuple, Dict import numpy as np import torch from PIL import Image from torch.utils.data import Dataset Loading @@ -11,14 +12,19 @@ from torchvision import transforms from imgutils.data import load_image from .prob import get_reg_for_prob TRAIN_TRANSFORM = transforms.Compose([ transforms.Resize(300), transforms.RandomCrop(250, padding=25, pad_if_needed=True, padding_mode='reflect'), TRAIN_TRANSFORM = [ transforms.Resize(416), transforms.RandomRotation((-15, 15)), transforms.RandomCrop(384), transforms.RandomHorizontalFlip(), transforms.ColorJitter(0.10, 0.10, 0.10, 0.10), ]) TEST_TRANSFORM = transforms.Compose([]) transforms.ToTensor(), ] TEST_TRANSFORM = [ transforms.Resize(416), transforms.CenterCrop(384), transforms.ToTensor(), ] class ImagesDataset(Dataset): Loading @@ -37,7 +43,7 @@ class ImagesDataset(Dataset): return image, idx def split_dataset(self, test_prob: float = 0.2): def split_dataset(self, test_prob: float = 0.2, train_transform=None, test_transform=None): total = len(self.items) test_ids = set(random.sample(list(range(total)), k=int(total * test_prob))) Loading @@ -48,7 +54,7 @@ class ImagesDataset(Dataset): else: train_items.append(item) return ImagesDataset(train_items, self.transform), ImagesDataset(test_items, self.transform) return ImagesDataset(train_items, train_transform or self.transform), ImagesDataset(test_items, test_transform or self.transform) class CCIPImagesDataset(ImagesDataset): Loading Loading @@ -116,3 +122,49 @@ class CharacterDataset(Dataset): return torch.stack(list(map(torch.as_tensor, images))), \ torch.stack(list(map(torch.as_tensor, labels))) class FastCharacterDataset(Dataset): def __init__(self, images_dataset: ImagesDataset, group_size: int = 100, prob: float = 0.5, **kwargs): self.images_dataset = images_dataset groups: Dict[int, List[int]] = {} for i, (_, cid) in enumerate(self.images_dataset.items): if cid not in groups: groups[cid] = [] groups[cid].append(i) self.groups = {k:np.array(v) for k,v in groups.items()} self.group_size = group_size self.prob = prob*2 def reset(self): idxs = np.arange(0, len(self.images_dataset.items)) np.random.shuffle(idxs) self.idxs = idxs[:-(len(idxs)%self.group_size)] def __len__(self): return len(self.idxs) def __getitem__(self, item): image, cid = self.images_dataset[self.idxs[item]] n_same = int(self.prob) + int((self.prob-int(self.prob))<=(item%self.group_size/self.group_size)) image = [image] cid = [cid] if n_same>0: same_idxs = random.sample(self.groups[cid], k=n_same) for idx in same_idxs: img_i, cid_i = self.images_dataset[idx] image.append(img_i) cid.append(cid_i) return image, cid def char_collect_fn(batch): img_list, cid_list = [], [] for data in batch: img_list.extend(data[0]) cid_list.extend(data[1]) return torch.stack(img_list), torch.tensor(cid_list) No newline at end of file
zoo/ccip/loss.py +22 −0 Original line number Diff line number Diff line Loading @@ -52,3 +52,25 @@ class NTXentLoss(nn.Module): pos_tensor = torch.stack(pos_items) return (pos_tensor.sum() + self.eps) / (pos_tensor.shape[0] + self.eps) class MLCELoss(nn.Module): def __init__(self, weight=None, reduction='mean', eps=1e-4): super().__init__() weight = torch.as_tensor(weight).float() if weight is not None else weight self.register_buffer('weight', weight) self.reduction = reduction self.eps = eps def forward(self, input_tensor, target_tensor): log_prob_raw = F.softmax(input_tensor, dim=1) same_mask = (target_tensor.unsqueeze(0) == target_tensor.unsqueeze(1)).long() # [B,B] same_mask_diag0 = same_mask - torch.diag_embed(torch.diag(same_mask)) # diag=0 log_prob_x = log_prob_raw.clone() log_prob_x[same_mask_diag0.bool()] = self.eps log_prob_x.diagonal().copy_((log_prob_raw*same_mask_diag0).sum(dim=1)) log_prob_x = log_prob_x + torch.diag_embed(torch.ones(len(target_tensor))*self.eps) y = torch.arange(0, len(target_tensor)) return F.nll_loss(log_prob_x.log(), y, weight=self.weight, reduction=self.reduction) No newline at end of file
zoo/ccip/model.py +34 −19 Original line number Diff line number Diff line Loading @@ -2,19 +2,29 @@ import torch.nn import torch.nn.functional as F from PIL import Image from torch import nn import numpy as np from zoo.utils import get_testfile #from zoo.utils import get_testfile from .backbone import get_backbone class CCIPBatchMetrics(nn.Module): def __init__(self): nn.Module.__init__(self) self.sim = nn.CosineSimilarity(dim=-1) self.logit_scale = nn.Parameter(torch.ones([])*np.log(1/0.07)) #self.sim = nn.CosineSimilarity(dim=-1) def forward(self, x): # x: BxN x = self.sim(x, x.unsqueeze(1)) # BxB return x def forward(self, image_features): # x: BxN # normalized features image_features = image_features/image_features.norm(dim=1, keepdim=True) # cosine similarity as logits logit_scale = self.logit_scale.exp() logits_per_image = logit_scale*image_features@image_features.t() logits_per_image = logits_per_image - torch.diag_embed(torch.diag(logits_per_image)) return logits_per_image class CCIPFeature(torch.nn.Module): Loading Loading @@ -46,17 +56,22 @@ class CCIP(torch.nn.Module): if __name__ == '__main__': image_files = [ get_testfile('6124220.jpg'), get_testfile('6125785.jpg'), get_testfile('6125901.jpg'), ] model = CCIP() data = torch.stack([ model.preprocess(Image.open(img)) for img in image_files ]) print(data.dtype, data.shape) print(F.softmax(model.forward(data), dim=-1)) # image_files = [ # get_testfile('6124220.jpg'), # get_testfile('6125785.jpg'), # get_testfile('6125901.jpg'), # ] # # model = CCIP() # data = torch.stack([ # model.preprocess(Image.open(img)) # for img in image_files # ]) # print(data.dtype, data.shape) # # print(F.softmax(model.forward(data), dim=-1)) data = torch.randn(4,3,384,384).cuda() model = CCIP('caformer').cuda() print(model(data))
zoo/ccip/train_.py +52 −32 Original line number Diff line number Diff line Loading @@ -12,10 +12,11 @@ from sklearn.metrics import accuracy_score from torch.optim import lr_scheduler from torch.utils.tensorboard import SummaryWriter from torchvision.transforms import Compose from torch.utils.data import DataLoader from tqdm.auto import tqdm from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset, TEST_TRANSFORM from .loss import NTXentLoss from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset, TEST_TRANSFORM, char_collect_fn from .loss import NTXentLoss, MLCELoss from .model import CCIP from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR Loading Loading @@ -81,8 +82,8 @@ def _sample_analysis(poss, negs, svm_samples: int = 10000): def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None, train_ratio: float = 0.8, max_epochs: int = 500, group_size: int = 30, learning_rate: float = 0.001, weight_decay: float = 1e-3, tau: float = 0.15, save_per_epoch: int = 10, eval_epoch: int = 5, learning_rate: float = 0.001, weight_decay: float = 1e-2, tau: float = 0.15, save_per_epoch: int = 10, eval_epoch: int = 5, num_workers=8, model_name: str = 'clip/ViT-B/32', seed: Optional[int] = 0): if seed is not None: # native random, numpy, torch and faker's seeds are includes Loading Loading @@ -113,12 +114,15 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio model = CCIP(model_name) image_dataset = CCIPImagesDataset(dataset_dir) train_image_dataset, test_image_dataset = image_dataset.split_dataset(test_prob=1 - train_ratio) train_image_dataset.transform = Compose([*TRAIN_TRANSFORM.transforms, *model.preprocess.transforms]) test_image_dataset.transform = Compose([*TEST_TRANSFORM.transforms, *model.preprocess.transforms]) train_image_dataset, test_image_dataset = image_dataset.split_dataset(test_prob=1 - train_ratio, train_transform=Compose(TRAIN_TRANSFORM.transforms+model.preprocess.transforms), test_transform=Compose(TEST_TRANSFORM.transforms+model.preprocess.transforms),) train_dataset = CharacterDataset(train_image_dataset, group_size, force_prob=False) test_dataset = CharacterDataset(test_image_dataset, group_size) train_dataloader = DataLoader(train_dataset, batch_size=group_size, shuffle=True, num_workers=num_workers, collate_fn=char_collect_fn, drop_last=True) test_dataloader = DataLoader(test_dataset, batch_size=group_size, num_workers=num_workers, collate_fn=char_collect_fn) if from_ckpt is None: from_ckpt = _find_latest_ckpt(session_name) Loading @@ -129,7 +133,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio else: logging.info(f'No checkpoint found, new model will be used.') loss_fn = NTXentLoss(tau=tau).to(accelerator.device) loss_fn = MLCELoss().to(accelerator.device) optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) scheduler = lr_scheduler.OneCycleLR( optimizer, max_lr=learning_rate, Loading @@ -137,36 +141,52 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio pct_start=0.15, final_div_factor=20. ) model, optimizer, train_dataset, test_dataset, scheduler = \ accelerator.prepare(model, optimizer, train_dataset, test_dataset, scheduler) model, optimizer, train_dataloader, test_dataloader, scheduler = \ accelerator.prepare(model, optimizer, train_dataloader, test_dataloader, scheduler) test_dataloader.dataset.reset() for epoch in range(previous_epoch + 1, max_epochs + 1): running_loss = 0.0 train_pos_total = 0 positive_sims, negative_sims = [], [] model.train() for i, (inputs, char_ids) in enumerate(tqdm(train_dataset)): inputs = inputs.float() for i, (inputs, char_ids) in enumerate(tqdm(train_dataloader)): train_dataloader.dataset.reset() inputs = inputs.to(accelerator.device) # BxCxHxW char_ids = char_ids.to(accelerator.device) # B ix = torch.arange(0, char_ids.shape[0]) mask = ix > ix.reshape(-1, 1) # BxB, remove duplicated similarities = model(inputs) # BxB outputs = similarities[mask] # N labels = (char_ids == char_ids.reshape(-1, 1))[mask].to(accelerator.device) # N # B = len(char_ids) # mask = torch.triu(torch.ones(B,B),diagonal=1).to(accelerator.device) # BxB, remove duplicated # similarities = model(inputs) # BxB # outputs = similarities[mask] # N # labels = (char_ids.view(-1,1) == char_ids.view(1,-1))[mask] # N # labels = char_ids outputs = model(inputs) # BxB labels = char_ids loss = loss_fn(outputs, labels) accelerator.backward(loss) optimizer.step() train_pos_total += labels.sum() running_loss += loss.item() * labels.sum() scheduler.step() positive_sims.append(outputs[labels]) negative_sims.append(outputs[~labels]) epoch_loss = running_loss / train_pos_total running_loss += loss.item()*len(char_ids) gt = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu() outputs = outputs.detach().cpu() gt_diag0 = gt.clone() gt_diag0.diagonal().copy_(torch.zeros(len(char_ids))) # outputs.diagonal().copy_(torch.ones(len(char_ids))*-10000) # max_idxs = outputs.argsort(dim=-1) # for max_idx, n_pos in zip(max_idxs, gt.sum(dim=1)): # train_pos_total += n_pos # positive_sims.append(outputs[labels]) # negative_sims.append(outputs[~labels]) train_pos_total += gt_diag0.sum() positive_sims.append(outputs[gt_diag0]) negative_sims.append(outputs[~gt]) epoch_loss = running_loss #/ train_pos_total train_psims = torch.cat(positive_sims) train_nsims = torch.cat(negative_sims) train_pos_mean, train_pos_std, train_neg_mean, train_neg_std, train_threshold, train_acc_svm = \ Loading @@ -188,19 +208,19 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio if epoch % eval_epoch == 0: with torch.no_grad(): positive_sims, negative_sims = [], [] for i, (inputs, char_ids) in enumerate(tqdm(test_dataset)): inputs = inputs.float() for i, (inputs, char_ids) in enumerate(tqdm(test_dataloader)): inputs = inputs.to(accelerator.device) # BxCxHxW char_ids = char_ids.to(accelerator.device) # B ix = torch.arange(0, char_ids.shape[0]) mask = ix > ix.reshape(-1, 1) # BxB, remove duplicated similarities = model(inputs) # BxB outputs = similarities[mask] # N labels = (char_ids == char_ids.reshape(-1, 1))[mask].to(accelerator.device) # N outputs = model(inputs) # BxB positive_sims.append(outputs[labels]) negative_sims.append(outputs[~labels]) gt = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu() outputs = outputs.detach().cpu() gt_diag0 = gt.clone() gt_diag0.diagonal().copy_(torch.zeros(len(char_ids))) train_pos_total += gt_diag0.sum() positive_sims.append(outputs[gt_diag0]) negative_sims.append(outputs[~gt]) test_psims = torch.cat(positive_sims) test_nsims = torch.cat(negative_sims) Loading