Commit 5285d60a authored by dzy7e's avatar dzy7e
Browse files

demo fp16

parent 94df68e4
Loading
Loading
Loading
Loading
+6 −1
Original line number Diff line number Diff line
@@ -14,6 +14,8 @@ class Infer:

        self.model = CCIP(args.model_name).to(device)
        self.model.eval()
        if self.args.fp16:
            self.model = self.model.half()
        state = torch.load(args.ckpt)
        try:
            self.model.load_state_dict(state)
@@ -31,6 +33,8 @@ class Infer:
    @torch.no_grad()
    def infer_one(self, img_list):
        imgs = torch.stack(img_list).to(self.device)
        if self.args.fp16:
            imgs = imgs.half()
        outputs = self.model(imgs)
        return outputs

@@ -38,7 +42,8 @@ class Infer:
    def build_args():
        parser = argparse.ArgumentParser(description='Stable Diffusion Training')
        parser.add_argument('--model_name', type=str, default='caformer')
        parser.add_argument('--ckpt', type=str, default='ckpts/ccip-caformer-2.ckpt')
        parser.add_argument('--ckpt', type=str, default='ckpts/ccip-caformer-2_fp16.ckpt')
        parser.add_argument('--fp16', default=None, action="store_true")
        return parser.parse_args()

if __name__ == '__main__':