Loading zoo/ccip/demo.py +6 −1 Original line number Diff line number Diff line Loading @@ -14,6 +14,8 @@ class Infer: self.model = CCIP(args.model_name).to(device) self.model.eval() if self.args.fp16: self.model = self.model.half() state = torch.load(args.ckpt) try: self.model.load_state_dict(state) Loading @@ -31,6 +33,8 @@ class Infer: @torch.no_grad() def infer_one(self, img_list): imgs = torch.stack(img_list).to(self.device) if self.args.fp16: imgs = imgs.half() outputs = self.model(imgs) return outputs Loading @@ -38,7 +42,8 @@ class Infer: def build_args(): parser = argparse.ArgumentParser(description='Stable Diffusion Training') parser.add_argument('--model_name', type=str, default='caformer') parser.add_argument('--ckpt', type=str, default='ckpts/ccip-caformer-2.ckpt') parser.add_argument('--ckpt', type=str, default='ckpts/ccip-caformer-2_fp16.ckpt') parser.add_argument('--fp16', default=None, action="store_true") return parser.parse_args() if __name__ == '__main__': Loading Loading
zoo/ccip/demo.py +6 −1 Original line number Diff line number Diff line Loading @@ -14,6 +14,8 @@ class Infer: self.model = CCIP(args.model_name).to(device) self.model.eval() if self.args.fp16: self.model = self.model.half() state = torch.load(args.ckpt) try: self.model.load_state_dict(state) Loading @@ -31,6 +33,8 @@ class Infer: @torch.no_grad() def infer_one(self, img_list): imgs = torch.stack(img_list).to(self.device) if self.args.fp16: imgs = imgs.half() outputs = self.model(imgs) return outputs Loading @@ -38,7 +42,8 @@ class Infer: def build_args(): parser = argparse.ArgumentParser(description='Stable Diffusion Training') parser.add_argument('--model_name', type=str, default='caformer') parser.add_argument('--ckpt', type=str, default='ckpts/ccip-caformer-2.ckpt') parser.add_argument('--ckpt', type=str, default='ckpts/ccip-caformer-2_fp16.ckpt') parser.add_argument('--fp16', default=None, action="store_true") return parser.parse_args() if __name__ == '__main__': Loading