Commit 116aa465 authored by dzy7e
Browse files

ccip clip

parent 22053872
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -134,7 +134,7 @@ class FastCharacterDataset(Dataset):
            if cid not in groups:
                groups[cid] = []
            groups[cid].append(i)
        self.groups = {k:np.array(v) for k,v in groups.items()}
        self.groups = groups

        self.group_size = group_size
        self.prob = prob*2
@@ -149,12 +149,12 @@ class FastCharacterDataset(Dataset):

    def __getitem__(self, item):
        image, cid = self.images_dataset[self.idxs[item]]
        n_same = int(self.prob) + int((self.prob-int(self.prob))<=(item%self.group_size/self.group_size))
        n_same = int(self.prob) + int((self.prob-int(self.prob))<=((item%self.group_size+1)/self.group_size))

        image = [image]
        cid = [cid]
        if n_same>0:
            same_idxs = random.sample(self.groups[cid], k=n_same)
            same_idxs = random.sample(self.groups[cid[0]], k=n_same)
            for idx in same_idxs:
                img_i, cid_i = self.images_dataset[idx]
                image.append(img_i)
+2 −2
Original line number Diff line number Diff line
@@ -70,7 +70,7 @@ class MLCELoss(nn.Module):
        log_prob_x = log_prob_raw.clone()
        log_prob_x[same_mask_diag0.bool()] = self.eps
        log_prob_x.diagonal().copy_((log_prob_raw*same_mask_diag0).sum(dim=1))
        log_prob_x = log_prob_x + torch.diag_embed(torch.ones(len(target_tensor))*self.eps)
        y = torch.arange(0, len(target_tensor))
        log_prob_x = log_prob_x + torch.diag_embed(torch.ones(len(target_tensor))*self.eps).to(input_tensor.device)
        y = torch.arange(0, len(target_tensor)).to(input_tensor.device)

        return F.nll_loss(log_prob_x.log(), y, weight=self.weight, reduction=self.reduction)
 No newline at end of file
+10 −7
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ import re
from typing import Optional

import torch
from accelerate import Accelerator
from accelerate import Accelerator, DistributedDataParallelKwargs
from ditk import logging
from hbutils.random import global_seed
from sklearn import svm
@@ -15,7 +15,7 @@ from torchvision.transforms import Compose
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset, TEST_TRANSFORM, char_collect_fn
from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset, FastCharacterDataset, TEST_TRANSFORM, char_collect_fn
from .loss import NTXentLoss, MLCELoss
from .model import CCIP
from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR
@@ -91,9 +91,11 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        # https://hansbug.github.io/hbutils/main/api_doc/random/state.html#register-random-source
        global_seed(seed)

    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        # mixed_precision=self.cfgs.mixed_precision,
        step_scheduler_with_optimizer=False,
        kwargs_handlers=[ddp_kwargs],
    )

    session_name = session_name or re.sub(r'\W+', '-', model_name)
@@ -115,11 +117,13 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    model = CCIP(model_name)
    image_dataset = CCIPImagesDataset(dataset_dir)
    train_image_dataset, test_image_dataset = image_dataset.split_dataset(test_prob=1 - train_ratio,
                                    train_transform=Compose(TRAIN_TRANSFORM.transforms+model.preprocess.transforms),
                                    test_transform=Compose(TEST_TRANSFORM.transforms+model.preprocess.transforms),)
                                    train_transform=Compose(TRAIN_TRANSFORM+model.preprocess),
                                    test_transform=Compose(TEST_TRANSFORM+model.preprocess),)

    train_dataset = CharacterDataset(train_image_dataset, group_size, force_prob=False)
    test_dataset = CharacterDataset(test_image_dataset, group_size)
    train_dataset = FastCharacterDataset(train_image_dataset, group_size, force_prob=False)
    test_dataset = FastCharacterDataset(test_image_dataset, group_size)
    train_dataset.reset()
    test_dataset.reset()
    train_dataloader = DataLoader(train_dataset, batch_size=group_size, shuffle=True, num_workers=num_workers, collate_fn=char_collect_fn,
                                  drop_last=True)
    test_dataloader = DataLoader(test_dataset, batch_size=group_size, num_workers=num_workers, collate_fn=char_collect_fn)
@@ -143,7 +147,6 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio

    model, optimizer, train_dataloader, test_dataloader, scheduler = \
        accelerator.prepare(model, optimizer, train_dataloader, test_dataloader, scheduler)
    test_dataloader.dataset.reset()

    for epoch in range(previous_epoch + 1, max_epochs + 1):
        running_loss = 0.0