Commit 79ddc2cc authored by dmMaze's avatar dmMaze
Browse files

support mps acceleration for macos

parent d44b4b86
Loading
Loading
Loading
Loading
+23 −4
Original line number Diff line number Diff line
import gc
import os
from typing import Dict
from copy import deepcopy

from utils.logger import logger as LOGGER
import gc

GPUINTENSIVE_SET = {'cuda'}

@@ -42,16 +45,32 @@ class BaseModule:
            return True
        return False


import torch

if hasattr(torch, 'cuda'):
    DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
elif hasattr(torch, 'backends') and hasattr(torch.backends, 'mps'):
    DEFAULT_DEVICE = 'mps'
else:
    DEFAULT_DEVICE = 'cpu'

def gc_collect():
    gc.collect()
    if DEFAULT_DEVICE == 'cuda':
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif DEFAULT_DEVICE == 'mps':
        torch.mps.empty_cache()

DEVICE_SELECTOR = lambda : deepcopy(
    {
        'type': 'selector',
        'options': [
            'cpu',
            'cuda',
            'mps'
        ],
        'select': DEFAULT_DEVICE
    }
)
+3 −17
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@ from utils.registry import Registry
from utils.textblock_mask import extract_ballon_mask
from utils.imgproc_utils import enlarge_window

from ..base import BaseModule, DEFAULT_DEVICE, gc_collect
from ..base import BaseModule, DEFAULT_DEVICE, gc_collect, DEVICE_SELECTOR
from ..textdetector import TextBlock

INPAINTERS = Registry('inpainters')
@@ -140,14 +140,7 @@ class AOTInpainter(InpainterBase):
            ], 
            'select': 2048
        }, 
        'device': {
            'type': 'selector',
            'options': [
                'cpu',
                'cuda',
            ],
            'select': DEFAULT_DEVICE
        },
        'device': DEVICE_SELECTOR(),
        'description': 'manga-image-translator inpainter'
    }

@@ -245,14 +238,7 @@ class LamaInpainterMPE(InpainterBase):
            ], 
            'select': 2048
        }, 
        'device': {
            'type': 'selector',
            'options': [
                'cpu',
                'cuda',
            ],
            'select': DEFAULT_DEVICE
        }
        'device': DEVICE_SELECTOR()
    }

    device = DEFAULT_DEVICE
+4 −25
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ from utils.registry import Registry
OCR = Registry('OCR')
register_OCR = OCR.register_module

from ..base import BaseModule, DEFAULT_DEVICE
from ..base import BaseModule, DEFAULT_DEVICE, DEVICE_SELECTOR

class OCRBase(BaseModule):

@@ -84,14 +84,7 @@ class OCRMIT32px(OCRBase):
            ],
            'select': 16
        },
        'device': {
            'type': 'selector',
            'options': [
                'cpu',
                'cuda',
            ],
            'select': DEFAULT_DEVICE
        },
        'device': DEVICE_SELECTOR(),
        'description': 'OCRMIT32px'
    }
    device = DEFAULT_DEVICE
@@ -132,14 +125,7 @@ MANGA_OCR_MODEL = None
@register_OCR('manga_ocr')
class MangaOCR(OCRBase):
    params = {
        'device': {
            'type': 'selector',
            'options': [
                'cpu',
                'cuda',
            ],
            'select': DEFAULT_DEVICE
        }
        'device': DEVICE_SELECTOR()
    }
    device = DEFAULT_DEVICE

@@ -202,14 +188,7 @@ class OCRMIT48pxCTC(OCRBase):
            ],
            'select': 16
        },
        'device': {
            'type': 'selector',
            'options': [
                'cpu',
                'cuda',
            ],
            'select': DEFAULT_DEVICE
        },
        'device': DEVICE_SELECTOR(),
        'description': 'mit48px_ctc'
    }
    device = DEFAULT_DEVICE
+2 −9
Original line number Diff line number Diff line
@@ -7,7 +7,7 @@ from utils.registry import Registry
TEXTDETECTORS = Registry('textdetectors')
register_textdetectors = TEXTDETECTORS.register_module

from ..base import BaseModule, DEFAULT_DEVICE
from ..base import BaseModule, DEFAULT_DEVICE, DEVICE_SELECTOR

class TextDetectorBase(BaseModule):

@@ -51,14 +51,7 @@ class ComicTextDetector(TextDetectorBase):
            'options': [1, 2, 4, 6, 8, 12, 16, 24, 32], 
            'select': 4
        },
        'device': {
            'type': 'selector',
            'options': [
                'cpu',
                'cuda',
            ],
            'select': DEFAULT_DEVICE
        },
        'device': DEVICE_SELECTOR(),
        'description': 'ComicTextDetector'
    }

+1 −1
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ import time, requests, re, uuid, base64, hmac, functools, json

from .exceptions import InvalidSourceOrTargetLanguage, TranslatorSetupFailure, MissingTranslatorParams, TranslatorNotValid
from ..textdetector.textblock import TextBlock
from ..base import BaseModule
from ..base import BaseModule, DEVICE_SELECTOR
from utils.registry import Registry
from utils.io_utils import text_is_empty
from .hooks import chs2cht
Loading