Unverified Commit 45d4369a authored by dmMaze's avatar dmMaze Committed by GitHub
Browse files

Merge pull request #15 from dmMaze/pytorch_dml_support

change backend semantics
parents 1b0a8ec0 48d096f7
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -5,7 +5,6 @@ result
ballontranslator/data/models
ballontranslator/data/testpacks/eng_dontupload
ballontranslator/data/testpacks/testpacks
ballontranslator/data/*.png
release

tmp.py
+12 −10
Original line number Diff line number Diff line
@@ -112,7 +112,8 @@ class AOTInpainter(InpainterBase):
            'type': 'selector',
            'options': [
                'cpu',
                'cuda'
                'cuda',
                'hip'
            ],
            'select': DEFAULT_DEVICE
        },
@@ -158,9 +159,9 @@ class AOTInpainter(InpainterBase):
        mask_torch[mask_torch < 0.5] = 0
        mask_torch[mask_torch >= 0.5] = 1

        if self.device == 'cuda':
            img_torch = img_torch.cuda()
            mask_torch = mask_torch.cuda()
        if self.device != 'cpu':
            img_torch = img_torch.to(self.device)
            mask_torch = mask_torch.to(self.device)
        img_torch *= (1 - mask_torch)
        return img_torch, mask_torch, img_original, mask_original, pad_bottom, pad_right

@@ -217,7 +218,8 @@ class LamaInpainterMPE(InpainterBase):
            'type': 'selector',
            'options': [
                'cpu',
                'cuda'
                'cuda',
                'hip'
            ],
            'select': DEFAULT_DEVICE
        }
@@ -266,11 +268,11 @@ class LamaInpainterMPE(InpainterBase):
        rel_pos = torch.LongTensor(rel_pos).unsqueeze_(0)
        direct = torch.LongTensor(direct).unsqueeze_(0)

        if self.device == 'cuda':
            img_torch = img_torch.cuda()
            mask_torch = mask_torch.cuda()
            rel_pos = rel_pos.cuda()
            direct = direct.cuda()
        if self.device != 'cpu':
            img_torch = img_torch.to(self.device)
            mask_torch = mask_torch.to(self.device)
            rel_pos = rel_pos.to(self.device)
            direct = direct.to(self.device)
        img_torch *= (1 - mask_torch)
        return img_torch, mask_torch, rel_pos, direct, img_original, mask_original, pad_bottom, pad_right

+6 −1
Original line number Diff line number Diff line
@@ -18,4 +18,9 @@ class ModuleParamParser:


import torch

if hasattr(torch, 'cuda'):
    DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
    DEFAULT_DEVICE = 'cpu'
+6 −3
Original line number Diff line number Diff line
@@ -63,7 +63,8 @@ class OCRMIT32px(OCRBase):
            'type': 'selector',
            'options': [
                'cpu',
                'cuda'
                'cuda',
                'hip'
            ],
            'select': DEFAULT_DEVICE
        },
@@ -115,7 +116,8 @@ class MangaOCR(OCRBase):
            'type': 'selector',
            'options': [
                'cpu',
                'cuda'
                'cuda',
                'hip'
            ],
            'select': DEFAULT_DEVICE
        }
@@ -174,7 +176,8 @@ class OCRMIT48pxCTC(OCRBase):
            'type': 'selector',
            'options': [
                'cpu',
                'cuda'
                'cuda',
                'hip'
            ],
            'select': DEFAULT_DEVICE
        },
+4 −4
Original line number Diff line number Diff line
@@ -402,8 +402,8 @@ class OCR48pxCTC:
        sd = torch.load(model_path, map_location = 'cpu')
        model.load_state_dict(sd['model'] if 'model' in sd else sd)
        model.eval()
        if self.device == 'cuda' :
            model = model.cuda()
        if self.device != 'cpu' :
            model = model.to(self.device)
        self.net = model

    def to(self, device: str) -> None:
@@ -437,8 +437,8 @@ class OCR48pxCTC:
                region[i, :, : W, :] = regions[idx]
            images = (torch.from_numpy(region).float() - 127.5) / 127.5
            images = einops.rearrange(images, 'N H W C -> N C H W')
            if self.device == 'cuda':
                images = images.cuda()
            if self.device != 'cpu':
                images = images.to(self.device)
            with torch.inference_mode() :
                texts = self.net.decode(images, widths, 0)
            for i, single_line in enumerate(texts) :
Loading