Loading zoo/ccip/demo.py +11 −4 Original line number Diff line number Diff line import argparse import torch from huggingface_hub import hf_hub_download from torchvision import transforms from imgutils.data import load_image Loading @@ -8,16 +9,21 @@ from .dataset import TEST_TRANSFORM from .model import CCIP def _load_remote_ckpt(remote_ckpt): return hf_hub_download('deepghs/ccip', remote_ckpt, repo_type='model') class Infer: def __init__(self, args, device='cuda'): def __init__(self, args, device=None): self.args = args self.device = device self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu') self.model = CCIP(args.model_name).to(device) self.model.eval() if self.args.fp16: self.model = self.model.half() state = torch.load(args.ckpt) state = torch.load(args.ckpt or _load_remote_ckpt(args.remote_ckpt), map_location='cpu') try: self.model.load_state_dict(state) except: Loading @@ -43,7 +49,8 @@ class Infer: def build_args(): parser = argparse.ArgumentParser(description='Stable Diffusion Training') parser.add_argument('--model_name', type=str, default='caformer') parser.add_argument('--ckpt', type=str, default='ckpts/ccip-caformer-2_fp16.ckpt') parser.add_argument('--ckpt', type=str, default='') parser.add_argument('--remote_ckpt', type=str, default='ccip-caformer-2_fp32.ckpt') parser.add_argument('--fp16', default=None, action="store_true") return parser.parse_args() Loading Loading
zoo/ccip/demo.py +11 −4 Original line number Diff line number Diff line import argparse import torch from huggingface_hub import hf_hub_download from torchvision import transforms from imgutils.data import load_image Loading @@ -8,16 +9,21 @@ from .dataset import TEST_TRANSFORM from .model import CCIP def _load_remote_ckpt(remote_ckpt): return hf_hub_download('deepghs/ccip', remote_ckpt, repo_type='model') class Infer: def __init__(self, args, device='cuda'): def __init__(self, args, device=None): self.args = args self.device = device self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu') self.model = CCIP(args.model_name).to(device) self.model.eval() if self.args.fp16: self.model = self.model.half() state = torch.load(args.ckpt) state = torch.load(args.ckpt or _load_remote_ckpt(args.remote_ckpt), map_location='cpu') try: self.model.load_state_dict(state) except: Loading @@ -43,7 +49,8 @@ class Infer: def build_args(): parser = argparse.ArgumentParser(description='Stable Diffusion Training') parser.add_argument('--model_name', type=str, default='caformer') parser.add_argument('--ckpt', type=str, default='ckpts/ccip-caformer-2_fp16.ckpt') parser.add_argument('--ckpt', type=str, default='') parser.add_argument('--remote_ckpt', type=str, default='ccip-caformer-2_fp32.ckpt') parser.add_argument('--fp16', default=None, action="store_true") return parser.parse_args() Loading