Commit bb616788 authored by dmMaze's avatar dmMaze
Browse files

inpainting fall back to cpu when CUDA OOM, fix #262

parent fc1ae361
Loading
Loading
Loading
Loading
+35 −9
Original line number Diff line number Diff line
@@ -34,6 +34,28 @@ class InpainterBase(BaseModule):
    def setup_inpainter(self):
        raise NotImplementedError
    
    def memory_safe_inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:
        '''
        handle cuda out of memory
        '''
        try:
            return self._inpaint(img, mask, textblock_list)
        except Exception as e:
            if DEFAULT_DEVICE == 'cuda' and isinstance(e, torch.cuda.OutOfMemoryError):
                gc_collect()
                try:
                    return self._inpaint(img, mask, textblock_list)
                except Exception as ee:
                    if isinstance(ee, torch.cuda.OutOfMemoryError):
                        self.logger.warning(f'CUDA out of memory while calling {self.name}, fall back to cpu...\n\
                                            if running into it frequently, consider lowering the inpaint_size')
                        self.moveToDevice('cpu')
                        inpainted = self._inpaint(img, mask, textblock_list)
                        self.moveToDevice('cuda')
                        return inpainted
            else:
                raise e

    def inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None, check_need_inpaint: bool = False) -> np.ndarray:
        if not self.inpaint_by_block or textblock_list is None:
            if check_need_inpaint:
@@ -49,14 +71,7 @@ class InpainterBase(BaseModule):
                        img = img.copy()
                        img[np.where(ballon_msk > 0)] = average_bg_color
                        return img
            try:
                return self._inpaint(img, mask)
            except Exception as e:
                if isinstance(e, torch.cuda.OutOfMemoryError):
                    gc_collect()
                    return self._inpaint(img, mask)
                else:
                    raise e
            return self.memory_safe_inpaint(img, mask, textblock_list)
        else:
            im_h, im_w = img.shape[:2]
            inpainted = np.copy(img)
@@ -84,7 +99,7 @@ class InpainterBase(BaseModule):
                        # cv2.waitKey(0)
                
                if need_inpaint:
                    inpainted[xyxy_e[1]:xyxy_e[3], xyxy_e[0]:xyxy_e[2]] = self._inpaint(im, msk)
                    inpainted[xyxy_e[1]:xyxy_e[3], xyxy_e[0]:xyxy_e[2]] = self.memory_safe_inpaint(im, msk)

                mask[xyxy[1]:xyxy[3], xyxy[0]:xyxy[2]] = 0
            return inpainted
@@ -92,6 +107,9 @@ class InpainterBase(BaseModule):
    def _inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:
        raise NotImplementedError
    
    def moveToDevice(self, device: str):
        raise not NotImplementedError


@register_inpainter('opencv-tela')
class OpenCVInpainter(InpainterBase):
@@ -163,6 +181,10 @@ class AOTInpainter(InpainterBase):
            self.model.to(self.device)
        self.inpaint_size = int(self.params['inpaint_size']['select'])

    def moveToDevice(self, device: str):
        self.model.to(device)
        self.device = device

    def inpaint_preprocess(self, img: np.ndarray, mask: np.ndarray) -> np.ndarray:

        img_original = np.copy(img)
@@ -244,6 +266,10 @@ class LamaInpainterMPE(InpainterBase):
    device = DEFAULT_DEVICE
    inpaint_size = 2048

    def moveToDevice(self, device: str):
        self.model.to(device)
        self.device = device

    def setup_inpainter(self):
        global LAMA_MPE