Commit 86e043d9 authored by minicom's avatar minicom
Browse files

Code simplification

parent 68c16fd9
Loading
Loading
Loading
Loading
+3 −12
Original line number Diff line number Diff line
@@ -202,20 +202,11 @@ def soft_empty_cache():
        torch.mps.empty_cache()


def DEVICE_SELECTOR(): return deepcopy(
def DEVICE_SELECTOR(not_supported:list[str]=[]): return deepcopy(
    {
        'type': 'selector',
        'options': AVAILABLE_DEVICES,
        'value': DEFAULT_DEVICE
    }
)


def DEVICE_SELECTOR_NO_DML(): return deepcopy(
    {
        'type': 'selector',
        'options': [opt for opt in AVAILABLE_DEVICES if not 'privateuseone' in opt],
        'value': DEFAULT_DEVICE if DEFAULT_DEVICE != "privateuseone" else "cpu"
        'options': [opt for opt in AVAILABLE_DEVICES if all(device not in opt for device in not_supported)],
        'value': DEFAULT_DEVICE if not any(DEFAULT_DEVICE in device for device in not_supported) else 'cpu'
    }
)

+3 −3
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ from utils.registry import Registry
from utils.textblock_mask import extract_ballon_mask
from utils.imgproc_utils import enlarge_window

from ..base import BaseModule, DEFAULT_DEVICE, soft_empty_cache, DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, GPUINTENSIVE_SET, TORCH_DTYPE_MAP, BF16_SUPPORTED
from ..base import BaseModule, DEFAULT_DEVICE, soft_empty_cache, DEVICE_SELECTOR, GPUINTENSIVE_SET, TORCH_DTYPE_MAP, BF16_SUPPORTED
from ..textdetector import TextBlock

INPAINTERS = Registry('inpainters')
@@ -292,7 +292,7 @@ class LamaInpainterMPE(InpainterBase):
            ], 
            'value': 2048
        },
        'device': DEVICE_SELECTOR_NO_DML()
        'device': DEVICE_SELECTOR(not_supported=['privateuseone'])
    }

    download_file_list = [{
@@ -415,7 +415,7 @@ class LamaLarge(LamaInpainterMPE):
            ], 
            'value': 1536,
        },
        'device': DEVICE_SELECTOR_NO_DML(),
        'device': DEVICE_SELECTOR(not_supported=['privateuseone']),
        'precision': {
            'type': 'selector',
            'options': [
+1 −1
Original line number Diff line number Diff line
from .base import OCRBase, register_OCR, DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, DEFAULT_DEVICE, TextBlock, OCR
from .base import OCRBase, register_OCR, DEVICE_SELECTOR, DEFAULT_DEVICE, TextBlock, OCR
+1 −1
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ from utils.registry import Registry
OCR = Registry('OCR')
register_OCR = OCR.register_module

from ..base import BaseModule, DEFAULT_DEVICE, DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, LOGGER
from ..base import BaseModule, DEFAULT_DEVICE, DEVICE_SELECTOR, LOGGER

class OCRBase(BaseModule):

+2 −2
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@ from typing import List
import numpy as np
from copy import deepcopy

from .base import DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, OCRBase, register_OCR, TextBlock
from .base import DEVICE_SELECTOR, OCRBase, register_OCR, TextBlock
from utils.textblock import collect_textblock_regions

mit_params = {
@@ -11,7 +11,7 @@ mit_params = {
        'options': [8, 16, 24, 32],
        'value': 16
    },
    'device': DEVICE_SELECTOR_NO_DML(),
    'device': DEVICE_SELECTOR(not_supported=['privateuseone']),
    'description': 'OCRMIT32px'
}