Commit 9d356e97 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): fix some bugs

parent 9ca6490b
Loading
Loading
Loading
Loading
+11 −4
Original line number Diff line number Diff line
import argparse

import torch
from huggingface_hub import hf_hub_download
from torchvision import transforms

from imgutils.data import load_image
@@ -8,16 +9,21 @@ from .dataset import TEST_TRANSFORM
from .model import CCIP


def _load_remote_ckpt(remote_ckpt):
    return hf_hub_download('deepghs/ccip', remote_ckpt, repo_type='model')


class Infer:
    def __init__(self, args, device='cuda'):
    def __init__(self, args, device=None):
        self.args = args
        self.device = device
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        self.model = CCIP(args.model_name).to(device)
        self.model.eval()
        if self.args.fp16:
            self.model = self.model.half()
        state = torch.load(args.ckpt)

        state = torch.load(args.ckpt or _load_remote_ckpt(args.remote_ckpt), map_location='cpu')
        try:
            self.model.load_state_dict(state)
        except:
@@ -43,7 +49,8 @@ class Infer:
    def build_args():
        parser = argparse.ArgumentParser(description='Stable Diffusion Training')
        parser.add_argument('--model_name', type=str, default='caformer')
        parser.add_argument('--ckpt', type=str, default='ckpts/ccip-caformer-2_fp16.ckpt')
        parser.add_argument('--ckpt', type=str, default='')
        parser.add_argument('--remote_ckpt', type=str, default='ccip-caformer-2_fp32.ckpt')
        parser.add_argument('--fp16', default=None, action="store_true")
        return parser.parse_args()