Commit 9ca6490b authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): reformat code

parent 5285d60a
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -16,3 +16,4 @@ accelerate
timm
ftfy
regex
torchmetrics
 No newline at end of file
+1 −1
Original line number Diff line number Diff line
import torch.nn
from torchvision.transforms import InterpolationMode, Compose, Resize, CenterCrop, ToTensor, Normalize
from torchvision.transforms import Normalize

from .attention_pool import AttentionPool2d
from ..monochrome.metaformer import CAFormerBuilder
+13 −12
Original line number Diff line number Diff line
import glob
import os.path
import random
from typing import List, Tuple, Dict
@@ -54,7 +53,8 @@ class ImagesDataset(Dataset):
            else:
                train_items.append(item)

        return ImagesDataset(train_items, train_transform or self.transform), ImagesDataset(test_items, test_transform or self.transform)
        return ImagesDataset(train_items, train_transform or self.transform), \
            ImagesDataset(test_items, test_transform or self.transform)


class CCIPImagesDataset(ImagesDataset):
@@ -165,6 +165,7 @@ class FastCharacterDataset(Dataset):

        return image, cid


def char_collect_fn(batch):
    img_list, cid_list = [], []
    for data in batch:
+12 −7
Original line number Diff line number Diff line
@@ -3,9 +3,10 @@ import argparse
import torch
from torchvision import transforms

from imgutils.data import load_image
from .dataset import TEST_TRANSFORM
from .model import CCIP
from imgutils.data import load_image


class Infer:
    def __init__(self, args, device='cuda'):
@@ -46,12 +47,16 @@ class Infer:
        parser.add_argument('--fp16', default=None, action="store_true")
        return parser.parse_args()


if __name__ == '__main__':
    demo = Infer(Infer.build_args())
    imgs = []
    imgs.append(demo.load_img(r'E:\dataset\pixiv\ganyu/11ee873afc5aacff2fd96248c1820c9240e922f6.jpg@942w_1320h_progressive.webp'))
    imgs.append(demo.load_img(r'E:\dataset\pixiv\ganyu/91039559171fd81f1ccb54838e1f546a4c3d6e7c.jpg@942w_942h_progressive.webp'))
    imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/eb7009f1dd5ecc61cf8d55f7d82c1922487b3cfc.jpg@942w_1338h_progressive.webp'))
    imgs.append(demo.load_img(
        r'E:\dataset\pixiv\ganyu/11ee873afc5aacff2fd96248c1820c9240e922f6.jpg@942w_1320h_progressive.webp'))
    imgs.append(demo.load_img(
        r'E:\dataset\pixiv\ganyu/91039559171fd81f1ccb54838e1f546a4c3d6e7c.jpg@942w_942h_progressive.webp'))
    imgs.append(
        demo.load_img(r'E:\dataset\pixiv\p1/eb7009f1dd5ecc61cf8d55f7d82c1922487b3cfc.jpg@942w_1338h_progressive.webp'))
    imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/c398774304db7cc737bb57fa2f380295.jpg'))
    imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/20221215165339_13707.png'))
    imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/e768ebbb4a116c8b85b39342d3348775.png'))
+6 −5
Original line number Diff line number Diff line
@@ -53,6 +53,7 @@ class NTXentLoss(nn.Module):
        pos_tensor = torch.stack(pos_items)
        return (pos_tensor.sum() + self.eps) / (pos_tensor.shape[0] + self.eps)


class MLCELoss(nn.Module):
    def __init__(self, weight=None, reduction='mean', eps=1e-4):
        super().__init__()
Loading