Commit ee48ae64 authored by dzy7e's avatar dzy7e
Browse files

infer batch

parent 7413682e
Loading
Loading
Loading
Loading
+29 −9
Original line number Diff line number Diff line
import argparse
import os

import torch
from huggingface_hub import hf_hub_download
from torchvision import transforms
from tqdm.auto import tqdm

from imgutils.data import load_image
from .dataset import TEST_TRANSFORM
@@ -56,6 +58,19 @@ class Infer:
        outputs = self.model(imgs)
        return outputs

    @torch.no_grad()
    def infer_batch(self, img_list, bs=8):
        feat_list = []
        for i in tqdm(range(0, len(img_list), bs)):
            imgs = torch.stack(img_list[i:i+bs]).to(self.device)
            if self.args.fp16:
                imgs = imgs.half()
            feat = self.model.feature(imgs)
            feat_list.append(feat)
        feat_list = torch.cat(feat_list, dim=0)
        outputs = self.model.metrics(feat_list)
        return outputs

    @staticmethod
    def build_args():
        parser = argparse.ArgumentParser(description='Stable Diffusion Training')
@@ -68,15 +83,20 @@ class Infer:

if __name__ == '__main__':
    demo = Infer(Infer.build_args())
    torch.set_printoptions(precision=2,sci_mode=False)
    imgs = []
    imgs.append(demo.load_img(
        r'E:\dataset\pixiv\ganyu/11ee873afc5aacff2fd96248c1820c9240e922f6.jpg@942w_1320h_progressive.webp'))
    imgs.append(demo.load_img(
        r'E:\dataset\pixiv\ganyu/91039559171fd81f1ccb54838e1f546a4c3d6e7c.jpg@942w_942h_progressive.webp'))
    imgs.append(
        demo.load_img(r'E:\dataset\pixiv\p1/eb7009f1dd5ecc61cf8d55f7d82c1922487b3cfc.jpg@942w_1338h_progressive.webp'))
    imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/c398774304db7cc737bb57fa2f380295.jpg'))
    imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/20221215165339_13707.png'))
    imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/e768ebbb4a116c8b85b39342d3348775.png'))
    # imgs.append(demo.load_img(
    #     r'E:\dataset\pixiv\ganyu/11ee873afc5aacff2fd96248c1820c9240e922f6.jpg@942w_1320h_progressive.webp'))
    # imgs.append(demo.load_img(
    #     r'E:\dataset\pixiv\ganyu/91039559171fd81f1ccb54838e1f546a4c3d6e7c.jpg@942w_942h_progressive.webp'))
    # imgs.append(
    #     demo.load_img(r'E:\dataset\pixiv\p1/eb7009f1dd5ecc61cf8d55f7d82c1922487b3cfc.jpg@942w_1338h_progressive.webp'))
    # imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/c398774304db7cc737bb57fa2f380295.jpg'))
    # imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/20221215165339_13707.png'))
    # imgs.append(demo.load_img(r'E:\dataset\pixiv\p1/e768ebbb4a116c8b85b39342d3348775.png'))
    root=r'E:\dataset\ccip_error_images'
    for path in os.listdir(root):
        print(path)
        imgs.append(demo.load_img(os.path.join(root,path)))
    pred = demo.infer_one(imgs)
    print(pred)