Loading zoo/ccip/demo.py +29 −9 Original line number Diff line number Diff line import argparse import os import torch from huggingface_hub import hf_hub_download from torchvision import transforms from tqdm.auto import tqdm from imgutils.data import load_image from .dataset import TEST_TRANSFORM Loading Loading @@ -56,6 +58,19 @@ class Infer: outputs = self.model(imgs) return outputs @torch.no_grad() def infer_batch(self, img_list, bs=8): feat_list = [] for i in tqdm(range(0, len(img_list), bs)): imgs = torch.stack(img_list[i:i+bs]).to(self.device) if self.args.fp16: imgs = imgs.half() feat = self.model.feature(imgs) feat_list.append(feat) feat_list = torch.cat(feat_list, dim=0) outputs = self.model.metrics(feat_list) return outputs @staticmethod def build_args(): parser = argparse.ArgumentParser(description='Stable Diffusion Training') Loading @@ -68,15 +83,20 @@ class Infer: if __name__ == '__main__': demo = Infer(Infer.build_args()) torch.set_printoptions(precision=2,sci_mode=False) imgs = [] imgs.append(demo.load_img( r'E:\dataset\pixiv\ganyu/11ee873afc5aacff2fd96248c1820c9240e922f6.jpg@942w_1320h_progressive.webp')) imgs.append(demo.load_img( r'E:\dataset\pixiv\ganyu/91039559171fd81f1ccb54838e1f546a4c3d6e7c.jpg@942w_942h_progressive.webp')) imgs.append( demo.load_img(r'E:\dataset\pixiv\p1/eb7009f1dd5ecc61cf8d55f7d82c1922487b3cfc.jpg@942w_1338h_progressive.webp')) imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/c398774304db7cc737bb57fa2f380295.jpg')) imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/20221215165339_13707.png')) imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/e768ebbb4a116c8b85b39342d3348775.png')) # imgs.append(demo.load_img( # r'E:\dataset\pixiv\ganyu/11ee873afc5aacff2fd96248c1820c9240e922f6.jpg@942w_1320h_progressive.webp')) # imgs.append(demo.load_img( # r'E:\dataset\pixiv\ganyu/91039559171fd81f1ccb54838e1f546a4c3d6e7c.jpg@942w_942h_progressive.webp')) # imgs.append( # demo.load_img(r'E:\dataset\pixiv\p1/eb7009f1dd5ecc61cf8d55f7d82c1922487b3cfc.jpg@942w_1338h_progressive.webp')) # imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/c398774304db7cc737bb57fa2f380295.jpg')) # imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/20221215165339_13707.png')) # imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/e768ebbb4a116c8b85b39342d3348775.png')) root=r'E:\dataset\ccip_error_images' for path in os.listdir(root): print(path) imgs.append(demo.load_img(os.path.join(root,path))) pred = demo.infer_one(imgs) print(pred) Loading
zoo/ccip/demo.py +29 −9 Original line number Diff line number Diff line import argparse import os import torch from huggingface_hub import hf_hub_download from torchvision import transforms from tqdm.auto import tqdm from imgutils.data import load_image from .dataset import TEST_TRANSFORM Loading Loading @@ -56,6 +58,19 @@ class Infer: outputs = self.model(imgs) return outputs @torch.no_grad() def infer_batch(self, img_list, bs=8): feat_list = [] for i in tqdm(range(0, len(img_list), bs)): imgs = torch.stack(img_list[i:i+bs]).to(self.device) if self.args.fp16: imgs = imgs.half() feat = self.model.feature(imgs) feat_list.append(feat) feat_list = torch.cat(feat_list, dim=0) outputs = self.model.metrics(feat_list) return outputs @staticmethod def build_args(): parser = argparse.ArgumentParser(description='Stable Diffusion Training') Loading @@ -68,15 +83,20 @@ class Infer: if __name__ == '__main__': demo = Infer(Infer.build_args()) torch.set_printoptions(precision=2,sci_mode=False) imgs = [] imgs.append(demo.load_img( r'E:\dataset\pixiv\ganyu/11ee873afc5aacff2fd96248c1820c9240e922f6.jpg@942w_1320h_progressive.webp')) imgs.append(demo.load_img( r'E:\dataset\pixiv\ganyu/91039559171fd81f1ccb54838e1f546a4c3d6e7c.jpg@942w_942h_progressive.webp')) imgs.append( demo.load_img(r'E:\dataset\pixiv\p1/eb7009f1dd5ecc61cf8d55f7d82c1922487b3cfc.jpg@942w_1338h_progressive.webp')) imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/c398774304db7cc737bb57fa2f380295.jpg')) imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/20221215165339_13707.png')) imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/e768ebbb4a116c8b85b39342d3348775.png')) # imgs.append(demo.load_img( # r'E:\dataset\pixiv\ganyu/11ee873afc5aacff2fd96248c1820c9240e922f6.jpg@942w_1320h_progressive.webp')) # imgs.append(demo.load_img( # r'E:\dataset\pixiv\ganyu/91039559171fd81f1ccb54838e1f546a4c3d6e7c.jpg@942w_942h_progressive.webp')) # imgs.append( # demo.load_img(r'E:\dataset\pixiv\p1/eb7009f1dd5ecc61cf8d55f7d82c1922487b3cfc.jpg@942w_1338h_progressive.webp')) # imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/c398774304db7cc737bb57fa2f380295.jpg')) # imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/20221215165339_13707.png')) # imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/e768ebbb4a116c8b85b39342d3348775.png')) root=r'E:\dataset\ccip_error_images' for path in os.listdir(root): print(path) imgs.append(demo.load_img(os.path.join(root,path))) pred = demo.infer_one(imgs) print(pred)