Commit e742c4ae authored by dmMaze's avatar dmMaze
Browse files

fix mps

parent 79ddc2cc
Loading
Loading
Loading
Loading
+5 −5
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ from copy import deepcopy

from utils.logger import logger as LOGGER

GPUINTENSIVE_SET = {'cuda'}
GPUINTENSIVE_SET = {'cuda', 'mps'}

class BaseModule:

@@ -45,14 +45,14 @@ class BaseModule:
            return True
        return False

os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import torch

if hasattr(torch, 'cuda'):
DEFAULT_DEVICE = 'cpu'
if hasattr(torch, 'cuda') and torch.cuda.is_available():
    DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
elif hasattr(torch, 'backends') and hasattr(torch.backends, 'mps'):
elif hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    DEFAULT_DEVICE = 'mps'
else:
    DEFAULT_DEVICE = 'cpu'

def gc_collect():
    gc.collect()
+5 −11
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@ from utils.registry import Registry
from utils.textblock_mask import extract_ballon_mask
from utils.imgproc_utils import enlarge_window

from ..base import BaseModule, DEFAULT_DEVICE, gc_collect, DEVICE_SELECTOR
from ..base import BaseModule, DEFAULT_DEVICE, gc_collect, DEVICE_SELECTOR, GPUINTENSIVE_SET
from ..textdetector import TextBlock

INPAINTERS = Registry('inpainters')
@@ -156,7 +156,7 @@ class AOTInpainter(InpainterBase):
        else:
            self.model = AOTMODEL
            self.model.to(self.device)
        self.inpaint_by_block = True if self.device == 'cuda' else False
        self.inpaint_by_block = self.device not in GPUINTENSIVE_SET
        self.inpaint_size = int(self.params['inpaint_size']['select'])

    def inpaint_preprocess(self, img: np.ndarray, mask: np.ndarray) -> np.ndarray:
@@ -214,10 +214,7 @@ class AOTInpainter(InpainterBase):
            param_device = self.params['device']['select']
            self.model.to(param_device)
            self.device = param_device
            if param_device == 'cuda':
                self.inpaint_by_block = False
            else:
                self.inpaint_by_block = True
            self.inpaint_by_block = param_device not in GPUINTENSIVE_SET

        elif param_key == 'inpaint_size':
            self.inpaint_size = int(self.params['inpaint_size']['select'])
@@ -253,7 +250,7 @@ class LamaInpainterMPE(InpainterBase):
        else:
            self.model = LAMA_MPE
            self.model.to(self.device)
        self.inpaint_by_block = True if self.device == 'cuda' else False
        self.inpaint_by_block = self.device not in GPUINTENSIVE_SET
        self.inpaint_size = int(self.params['inpaint_size']['select'])

    def inpaint_preprocess(self, img: np.ndarray, mask: np.ndarray) -> np.ndarray:
@@ -318,10 +315,7 @@ class LamaInpainterMPE(InpainterBase):
            param_device = self.params['device']['select']
            self.model.to(param_device)
            self.device = param_device
            if param_device == 'cuda':
                self.inpaint_by_block = False
            else:
                self.inpaint_by_block = True
            self.inpaint_by_block = param_device not in GPUINTENSIVE_SET

        elif param_key == 'inpaint_size':
            self.inpaint_size = int(self.params['inpaint_size']['select'])