Commit 73172d59 authored by John's avatar John
Browse files
parents 2b71da8d a7831af2
Loading
Loading
Loading
Loading
+12 −10
Original line number Diff line number Diff line
@@ -112,7 +112,8 @@ class AOTInpainter(InpainterBase):
            'type': 'selector',
            'options': [
                'cpu',
                'cuda'
                'cuda',
                'hip'
            ],
            'select': DEFAULT_DEVICE
        },
@@ -158,9 +159,9 @@ class AOTInpainter(InpainterBase):
        mask_torch[mask_torch < 0.5] = 0
        mask_torch[mask_torch >= 0.5] = 1

        if self.device == 'cuda':
            img_torch = img_torch.cuda()
            mask_torch = mask_torch.cuda()
        if self.device != 'cpu':
            img_torch = img_torch.to(self.device)
            mask_torch = mask_torch.to(self.device)
        img_torch *= (1 - mask_torch)
        return img_torch, mask_torch, img_original, mask_original, pad_bottom, pad_right

@@ -217,7 +218,8 @@ class LamaInpainterMPE(InpainterBase):
            'type': 'selector',
            'options': [
                'cpu',
                'cuda'
                'cuda',
                'hip'
            ],
            'select': DEFAULT_DEVICE
        }
@@ -266,11 +268,11 @@ class LamaInpainterMPE(InpainterBase):
        rel_pos = torch.LongTensor(rel_pos).unsqueeze_(0)
        direct = torch.LongTensor(direct).unsqueeze_(0)

        if self.device == 'cuda':
            img_torch = img_torch.cuda()
            mask_torch = mask_torch.cuda()
            rel_pos = rel_pos.cuda()
            direct = direct.cuda()
        if self.device != 'cpu':
            img_torch = img_torch.to(self.device)
            mask_torch = mask_torch.to(self.device)
            rel_pos = rel_pos.to(self.device)
            direct = direct.to(self.device)
        img_torch *= (1 - mask_torch)
        return img_torch, mask_torch, rel_pos, direct, img_original, mask_original, pad_bottom, pad_right

+6 −1
Original line number Diff line number Diff line
@@ -18,4 +18,9 @@ class ModuleParamParser:


import torch

if hasattr(torch, 'cuda'):
    DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
    DEFAULT_DEVICE = 'cpu'
+6 −3
Original line number Diff line number Diff line
@@ -63,7 +63,8 @@ class OCRMIT32px(OCRBase):
            'type': 'selector',
            'options': [
                'cpu',
                'cuda'
                'cuda',
                'hip'
            ],
            'select': DEFAULT_DEVICE
        },
@@ -115,7 +116,8 @@ class MangaOCR(OCRBase):
            'type': 'selector',
            'options': [
                'cpu',
                'cuda'
                'cuda',
                'hip'
            ],
            'select': DEFAULT_DEVICE
        }
@@ -174,7 +176,8 @@ class OCRMIT48pxCTC(OCRBase):
            'type': 'selector',
            'options': [
                'cpu',
                'cuda'
                'cuda',
                'hip'
            ],
            'select': DEFAULT_DEVICE
        },
+4 −4
Original line number Diff line number Diff line
@@ -402,8 +402,8 @@ class OCR48pxCTC:
        sd = torch.load(model_path, map_location = 'cpu')
        model.load_state_dict(sd['model'] if 'model' in sd else sd)
        model.eval()
        if self.device == 'cuda' :
            model = model.cuda()
        if self.device != 'cpu' :
            model = model.to(self.device)
        self.net = model

    def to(self, device: str) -> None:
@@ -437,8 +437,8 @@ class OCR48pxCTC:
                region[i, :, : W, :] = regions[idx]
            images = (torch.from_numpy(region).float() - 127.5) / 127.5
            images = einops.rearrange(images, 'N H W C -> N C H W')
            if self.device == 'cuda':
                images = images.cuda()
            if self.device != 'cpu':
                images = images.to(self.device)
            with torch.inference_mode() :
                texts = self.net.decode(images, widths, 0)
            for i, single_line in enumerate(texts) :
+6 −6
Original line number Diff line number Diff line
@@ -553,8 +553,8 @@ class OCR32pxModel:
        sd = torch.load(model_path, map_location = 'cpu')
        model.load_state_dict(sd['model'] if 'model' in sd else sd)
        model.eval()
        if device == 'cuda':
            model = model.cuda()
        if device != 'cpu':
            model = model.to(device)
        self.net = model

    def to(self, device: str) -> None:
@@ -586,8 +586,8 @@ class OCR32pxModel:
                region[i, :, : W, :] = regions[idx]
            images = (torch.from_numpy(region).float() - 127.5) / 127.5
            images = einops.rearrange(images, 'N H W C -> N C H W')
            if self.device == 'cuda':
                images = images.cuda()
            if self.device != 'cpu':
                images = images.to(self.device)
            ret = self.net.infer_beam_batch(images, widths, beams_k = 5, max_seq_length = 255)
            
            for i, (pred_chars_index, prob, fr, fg, fb, br, bg, bb) in enumerate(ret) :
@@ -629,8 +629,8 @@ class OCR32pxModel:
        widths = [img.shape[1]]
        img = (torch.from_numpy(img[np.newaxis, ...]).float() - 127.5) / 127.5
        img = einops.rearrange(img, 'N H W C -> N C H W')
        if self.device == 'cuda':
            images = images.cuda()
        if self.device != 'cpu':
            images = images.to(self.device)
        ret = self.net.infer_beam_batch(img, widths, beams_k = 5, max_seq_length = 255)
        for i, (pred_chars_index, prob, fr, fg, fb, br, bg, bb) in enumerate(ret) :
            if prob < 0.5 :
Loading