Commit 83928a96 authored by minicom's avatar minicom
Browse files

remove dml option where dml not supported

parent 8b44f917
Loading
Loading
Loading
Loading
+17 −4
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ from utils.logger import logger as LOGGER
from utils import shared


GPUINTENSIVE_SET = {'cuda', 'mps'}
GPUINTENSIVE_SET = {'mps', 'privateuseone'}

def register_hooks(hooks_registered: OrderedDict, callbacks: Union[List, Callable, Dict]):
    if callbacks is None:
@@ -189,13 +189,26 @@ def soft_empty_cache():
    elif DEFAULT_DEVICE == 'mps':
        torch.mps.empty_cache()

DEVICE_SELECTOR = lambda : deepcopy(

def DEVICE_SELECTOR(): return deepcopy(
    {
        'type': 'selector',
        'options': [
            'cpu',
            'mps',
            'privateuseone'
        ],
        'value': DEFAULT_DEVICE
    }
)


def DEVICE_SELECTOR_NO_DML(): return deepcopy(
    {
        'type': 'selector',
        'options': [
            'cpu',
            'cuda',
            'mps'
            'mps',
        ],
        'value': DEFAULT_DEVICE
    }
+5 −5
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ from utils.registry import Registry
from utils.textblock_mask import extract_ballon_mask
from utils.imgproc_utils import enlarge_window

from ..base import BaseModule, DEFAULT_DEVICE, soft_empty_cache, DEVICE_SELECTOR, GPUINTENSIVE_SET, TORCH_DTYPE_MAP, BF16_SUPPORTED
from ..base import BaseModule, DEFAULT_DEVICE, soft_empty_cache, DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, GPUINTENSIVE_SET, TORCH_DTYPE_MAP, BF16_SUPPORTED
from ..textdetector import TextBlock

INPAINTERS = Registry('inpainters')
@@ -292,7 +292,7 @@ class LamaInpainterMPE(InpainterBase):
            ], 
            'value': 2048
        },
        'device': DEVICE_SELECTOR()
        'device': DEVICE_SELECTOR_NO_DML()
    }

    download_file_list = [{
@@ -415,7 +415,7 @@ class LamaLarge(LamaInpainterMPE):
            ], 
            'value': 1536,
        },
        'device': DEVICE_SELECTOR(),
        'device': DEVICE_SELECTOR_NO_DML(),
        'precision': {
            'type': 'selector',
            'options': [
+1 −1
Original line number Diff line number Diff line
from .base import OCRBase, register_OCR, DEVICE_SELECTOR, DEFAULT_DEVICE, TextBlock, OCR
 No newline at end of file
from .base import OCRBase, register_OCR, DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, DEFAULT_DEVICE, TextBlock, OCR
+1 −1
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ from utils.registry import Registry
OCR = Registry('OCR')
register_OCR = OCR.register_module

from ..base import BaseModule, DEFAULT_DEVICE, DEVICE_SELECTOR, LOGGER
from ..base import BaseModule, DEFAULT_DEVICE, DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, LOGGER

class OCRBase(BaseModule):

+6 −2
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@ from typing import List
import numpy as np
from copy import deepcopy

from .base import DEVICE_SELECTOR, OCRBase, register_OCR, TextBlock
from .base import DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, OCRBase, register_OCR, TextBlock
from utils.textblock import collect_textblock_regions

mit_params = {
@@ -15,6 +15,10 @@ mit_params = {
    'description': 'OCRMIT32px'
}

mit_params_no_dml = deepcopy(mit_params)
mit_params_no_dml['device'] = DEVICE_SELECTOR_NO_DML()


class MITModels(OCRBase):

    _line_only = True
@@ -60,7 +64,7 @@ from .mit48px_ctc import OCR48pxCTC
@register_OCR('mit48px_ctc')
class OCRMIT48pxCTC(MITModels):

    params = deepcopy(mit_params)
    params = deepcopy(mit_params_no_dml)
    download_file_list = [{
        'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/ocr-ctc.zip',
        'files': ['ocr-ctc.ckpt', 'alphabet-all-v5.txt'],