Commit a660d895 authored by dmMaze's avatar dmMaze
Browse files

simplify config io

parent 8ea55f1e
Loading
Loading
Loading
Loading
+21 −4
Original line number Diff line number Diff line
from typing import List, Union, Tuple
import json

from qtpy.QtWidgets import QLayout, QHBoxLayout, QVBoxLayout, QTreeView, QWidget, QLabel, QSizePolicy, QSpacerItem, QCheckBox, QSplitter, QScrollArea, QGroupBox, QLineEdit
from qtpy.QtCore import Qt, QModelIndex, Signal, QSize
from qtpy.QtGui import QStandardItem, QStandardItemModel, QMouseEvent, QFont, QColor, QPalette
from PyQt5 import QtCore
from typing import List, Union, Tuple

from utils.logger import logger as LOGGER
from .stylewidgets import Widget, ConfigComboBox
from .misc import ProgramConfig, DLModuleConfig
from .constants import CONFIG_FONTSIZE_CONTENT, CONFIG_FONTSIZE_HEADER, CONFIG_FONTSIZE_TABLE, CONFIG_COMBOBOX_SHORT, CONFIG_COMBOBOX_LONG, CONFIG_COMBOBOX_MIDEAN
from .constants import CONFIG_PATH, CONFIG_FONTSIZE_CONTENT, CONFIG_FONTSIZE_HEADER, CONFIG_FONTSIZE_TABLE, CONFIG_COMBOBOX_SHORT, CONFIG_COMBOBOX_LONG, CONFIG_COMBOBOX_MIDEAN
from .dlconfig_parse_widgets import InpaintConfigPanel, TextDetectConfigPanel, TranslatorConfigPanel, OCRConfigPanel

class ConfigTextLabel(QLabel):
@@ -257,7 +260,16 @@ class GeneralPanel(QWidget):
class ConfigPanel(Widget):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        try:
            with open(CONFIG_PATH, 'r', encoding='utf8') as f:
                config_dict = json.loads(f.read())
            self.config = ProgramConfig(**config_dict)
        except Exception as e:
            LOGGER.exception(e)
            LOGGER.warning("Failed to load config file, using default config")
            self.config = ProgramConfig()

        self.configTable = ConfigTable()
        self.configTable.tableitem_pressed.connect(self.onTableItemPressed)
        self.configContent = ConfigContent()
@@ -317,13 +329,15 @@ class ConfigPanel(Widget):
        generalConfigPanel.addTextLabel(label_lettering)
        dec_program_str = self.tr('decide by program')
        use_global_str = self.tr('use global setting')
        to_uppercase_str = self.tr('To uppercase')
        self.let_fntsize_combox = generalConfigPanel.addCombobox([dec_program_str, use_global_str], self.tr('font size'))
        self.let_fntsize_combox.currentIndexChanged.connect(self.on_fntsize_flag_changed)
        self.let_fntstroke_combox = generalConfigPanel.addCombobox([dec_program_str, use_global_str], self.tr('stroke'))
        self.let_fntstroke_combox.currentIndexChanged.connect(self.on_fntstroke_flag_changed)
        self.let_fntcolor_combox = generalConfigPanel.addCombobox([dec_program_str, use_global_str], self.tr('font & stroke color'))
        self.let_fntcolor_combox.currentIndexChanged.connect(self.on_fontcolor_flag_changed)

        self.let_uppercase_checker = generalConfigPanel.addCheckBox(to_uppercase_str)
        self.let_uppercase_checker.stateChanged.connect(self.on_uppercase_changed)
        splitter = QSplitter(Qt.Orientation.Horizontal)
        splitter.addWidget(self.configTable)
        splitter.addWidget(self.configContent)
@@ -361,6 +375,9 @@ class ConfigPanel(Widget):
    def on_fntstroke_flag_changed(self):
        self.config.let_fntstroke_flag = self.let_fntstroke_combox.currentIndex()

    def on_uppercase_changed(self):
        self.config.let_uppercase_flag = self.let_uppercase_checker.isChecked()

    def on_fontcolor_flag_changed(self):
        self.config.let_fntcolor_flag = self.let_fntcolor_combox.currentIndex()

+2 −9
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ from qtpy.QtCore import Qt, QPoint, QSize
from qtpy.QtGui import QGuiApplication, QIcon, QCloseEvent, QKeySequence, QImage, QPainter

from utils.logger import logger as LOGGER
from utils.io_utils import json_dump_nested_obj
from .misc import ProjImgTrans, ndarray2pixmap, pixmap2ndarray
from .canvas import Canvas
from .configpanel import ConfigPanel
@@ -130,13 +131,6 @@ class MainWindow(QMainWindow):
    def setupConfig(self):
        with open(STYLESHEET_PATH, "r", encoding='utf-8') as f:
            self.setStyleSheet(f.read())
        try:
            with open(CONFIG_PATH, 'r', encoding='utf8') as f:
                config_dict = json.loads(f.read())
            self.config.load_from_dict(config_dict)
        except Exception as e:
            LOGGER.exception(e)
            LOGGER.warning("Failed to load config file, using default config")

        self.bottomBar.originalSlider.setValue(self.config.original_transparency * 100)
        self.drawingPanel.maskTransperancySlider.setValue(self.config.mask_transparency * 100)
@@ -236,9 +230,8 @@ class MainWindow(QMainWindow):
        self.config.mask_transparency = self.canvas.mask_transparency
        self.config.original_transparency = self.canvas.original_transparency
        self.config.drawpanel = self.drawingPanel.get_config()
        config_dict = self.config.to_dict()
        with open(CONFIG_PATH, 'w', encoding='utf8') as f:
            f.write(json.dumps(config_dict, ensure_ascii=False))
            f.write(json_dump_nested_obj(self.config))
        if not self.imgtrans_proj.is_empty:
            self.imgtrans_proj.save()
        return super().closeEvent(event)
+88 −137
Original line number Diff line number Diff line
@@ -86,30 +86,52 @@ class InvalidProgramConfigException(Exception):
    pass


PROJTYPE_IMGTRANS = 'imgtrans'
PROJTYPE_HARDSUBEXTRACT = 'hardsubextract'
class FontFormat:
    def __init__(self, 
                 family: str = None,
                 size: float = 24,
                 stroke_width: float = 0,
                 frgb=(0, 0, 0),
                 srgb=(0, 0, 0),
                 bold: bool = False,
                 underline: bool = False,
                 italic: bool = False, 
                 alignment: int = 0,
                 vertical: bool = False, 
                 weight: int = 50, 
                 alpha: int = 255,
                 line_spacing: float = 1) -> None:
        self.family = family if family is not None else DEFAULT_FONT_FAMILY
        self.size = size
        self.stroke_width = stroke_width
        self.frgb = frgb                  # font color
        self.srgb = srgb                    # stroke color
        self.bold = bold
        self.underline = underline
        self.italic = italic
        self.alpha = alpha
        self.weight: int = weight
        self.alignment: int = alignment
        self.vertical: bool = vertical
        self.line_spacing = line_spacing

class Proj:
    def __init__(self) -> None:
        pass
    @staticmethod
    def load(proj_path: str):
        try:
            with open(proj_path, 'r', encoding='utf8') as f:
                proj_dict = json.loads(f.read())
        except Exception as e:
            raise ProjectLoadFailureException(e)
        proj_type = proj_dict['type']
        if proj_type == PROJTYPE_IMGTRANS:
            proj = ProjImgTrans()
        elif proj_type == PROJTYPE_HARDSUBEXTRACT:
            proj = ProjHardSubExtract()
        else:
            raise NotImplementedProjException(proj_type)
        proj.load_from_dict(proj_dict)
        return proj
    def from_textblock(self, text_block: TextBlock):
        self.family = text_block.font_family
        self.size = px2pt(text_block.font_size)
        self.stroke_width = text_block.stroke_width
        self.frgb, self.srgb = text_block.get_font_colors()
        self.bold = text_block.bold
        self.weight = text_block.font_weight
        self.underline = text_block.underline
        self.italic = text_block.italic
        self.alignment = text_block.alignment()
        self.vertical = text_block.vertical
        self.line_spacing = text_block.line_spacing


PROJTYPE_IMGTRANS = 'imgtrans'
PROJTYPE_HARDSUBEXTRACT = 'hardsubextract'

class ProjImgTrans:

    def __init__(self, directory: str = None):
@@ -375,30 +397,11 @@ class DLModuleConfig:
        if inpainter_setup_params is None:
            self.inpainter_setup_params = dict()
        else:
            inpainter_setup_params = inpainter_setup_params
            self.inpainter_setup_params = inpainter_setup_params
        self.translate_source = translate_source
        self.translate_target = translate_target
        self.check_need_inpaint = check_need_inpaint

    def load_from_dict(self, config_dict: dict):
        try:
            self.textdetector = config_dict['textdetector']
            self.inpainter = config_dict['inpainter']
            self.ocr = config_dict['ocr']
            self.translator = config_dict['translator']
            self.enable_ocr = config_dict['enable_ocr']
            self.enable_translate = config_dict['enable_translate']
            self.enable_inpaint = config_dict['enable_inpaint']
            self.translator_setup_params = config_dict['translator_setup_params']
            self.inpainter_setup_params = config_dict['inpainter_setup_params']
            self.textdetector_setup_params = config_dict['textdetector_setup_params']
            self.ocr_setup_params = config_dict['ocr_setup_params']
            self.translate_source = config_dict['translate_source']
            self.translate_target = config_dict['translate_target']
            self.check_need_inpaint = config_dict['check_need_inpaint']
        except Exception as e:
            raise InvalidProgramConfigException(e)

    def __getitem__(self, item: str):
        if item == 'textdetector':
            return self.textdetector
@@ -437,105 +440,53 @@ class DrawPanelConfig:
        self.rectool_auto = rectool_auto
        self.rectool_method = rectool_method


class ProgramConfig:
    def __init__(self, config_dict=None) -> None:
    def __init__(self, 
                 dl: Union[Dict, DLModuleConfig] = None,
                 drawpanel: Union[Dict, DrawPanelConfig] = None,
                 global_fontformat: Union[Dict, FontFormat] = None,
                 recent_proj_list: List[str] = list(),
                 imgtrans_paintmode: bool = False,
                 imgtrans_textedit: bool = True,
                 imgtrans_textblock: bool = True,
                 mask_transparency: float = 0.,
                 original_transparency: float = 0.,
                 open_recent_on_startup: bool = True, 
                 let_fntsize_flag: int = 0,
                 let_fntstroke_flag: int = 0,
                 let_fntcolor_flag: int = 0,
                 let_uppercase_flag: bool = True) -> None:
        if isinstance(dl, dict):
            self.dl = DLModuleConfig(**dl)
        elif dl is None:
            self.dl = DLModuleConfig()
        self.recent_proj_list: list = []
        self.imgtrans_paintmode = False
        self.imgtrans_textedit = True
        self.imgtrans_textblock = True
        self.mask_transparency = 0
        self.original_transparency = 0
        self.global_fontformat = FontFormat()
        else:
            self.dl = dl
        if isinstance(drawpanel, dict):
            self.drawpanel = DrawPanelConfig(**drawpanel)
        elif drawpanel is None:
            self.drawpanel = DrawPanelConfig()
        self.open_recent_on_startup = False
        self.let_fntsize_flag = 0
        self.let_fntstroke_flag = 0
        self.let_fntcolor_flag = 0
        if config_dict is not None:
            self.load_from_dict(config_dict)

    def load_from_dict(self, config_dict):
        try:
            # self.dl.load_from_dict(config_dict['dl'])
            self.dl.load_from_dict(config_dict['dl'])
            self.recent_proj_list = config_dict['recent_proj_list']
            self.imgtrans_paintmode = config_dict['imgtrans_paintmode']
            self.imgtrans_textedit = config_dict['imgtrans_textedit']
            self.imgtrans_textblock = config_dict['imgtrans_textblock']
            self.mask_transparency = config_dict['mask_transparency']
            self.original_transparency = config_dict['original_transparency']
            self.global_fontformat = FontFormat(**config_dict['global_fontformat'])
            self.drawpanel = DrawPanelConfig(**config_dict['drawpanel'])
            self.open_recent_on_startup = config_dict['open_recent_on_startup']
            self.let_fntsize_flag = config_dict['let_fntsize_flag']
            self.let_fntstroke_flag = config_dict['let_fntstroke_flag']
            self.let_fntcolor_flag = config_dict['let_fntcolor_flag']
        except Exception as e:
            raise InvalidProgramConfigException(e)

    def to_dict(self):
        return {
            'dl': vars(self.dl),
            'recent_proj_list': self.recent_proj_list,
            'imgtrans_textedit': self.imgtrans_textedit,
            'imgtrans_paintmode': self.imgtrans_paintmode,
            'imgtrans_textblock': self.imgtrans_textblock, 
            'global_fontformat': self.global_fontformat.to_dict(),
            'mask_transparency': self.mask_transparency,
            'original_transparency': self.original_transparency,
            'drawpanel': vars(self.drawpanel),
            'open_recent_on_startup': self.open_recent_on_startup,
            'let_fntsize_flag': self.let_fntsize_flag,
            'let_fntstroke_flag': self.let_fntstroke_flag,
            'let_fntcolor_flag': self.let_fntcolor_flag
        }


class FontFormat:
    def __init__(self, 
                 family: str = None,
                 size: float = 24,
                 stroke_width: float = 0,
                 frgb=(0, 0, 0),
                 srgb=(0, 0, 0),
                 bold: bool = False,
                 underline: bool = False,
                 italic: bool = False, 
                 alignment: int = 0,
                 vertical: bool = False, 
                 weight: int = 50, 
                 alpha: int = 255,
                 line_spacing: float = 1) -> None:
        self.family = family if family is not None else DEFAULT_FONT_FAMILY
        self.size = size
        self.stroke_width = stroke_width
        self.frgb = frgb                  # font color
        self.srgb = srgb                    # stroke color
        self.bold = bold
        self.underline = underline
        self.italic = italic
        self.alpha = alpha
        self.weight: int = weight
        self.alignment: int = alignment
        self.vertical: bool = vertical
        self.line_spacing = line_spacing

    def from_textblock(self, text_block: TextBlock):
        self.family = text_block.font_family
        self.size = px2pt(text_block.font_size)
        self.stroke_width = text_block.stroke_width
        self.frgb, self.srgb = text_block.get_font_colors()
        self.bold = text_block.bold
        self.weight = text_block.font_weight
        self.underline = text_block.underline
        self.italic = text_block.italic
        self.alignment = text_block.alignment()
        self.vertical = text_block.vertical
        self.line_spacing = text_block.line_spacing
        else:
            self.drawpanel = drawpanel
        if isinstance(global_fontformat, dict):
            self.global_fontformat = FontFormat(**global_fontformat)
        elif global_fontformat is None:
            self.global_fontformat = FontFormat()
        else:
            self.global_fontformat = global_fontformat
        self.recent_proj_list = recent_proj_list
        self.imgtrans_paintmode = imgtrans_paintmode
        self.imgtrans_textedit = imgtrans_textedit
        self.imgtrans_textblock = imgtrans_textblock
        self.mask_transparency = mask_transparency
        self.original_transparency = original_transparency
        self.open_recent_on_startup = open_recent_on_startup
        self.let_fntsize_flag = let_fntsize_flag
        self.let_fntstroke_flag = let_fntstroke_flag
        self.let_fntcolor_flag = let_fntcolor_flag
        self.let_uppercase_flag = let_uppercase_flag

    def to_dict(self):
        return vars(self)

span_pattern = re.compile(r'<span style=\"(.*?)\">', re.DOTALL)
p_pattern = re.compile(r'<p style=\"(.*?)\">', re.DOTALL)
+6 −0
Original line number Diff line number Diff line
@@ -8,6 +8,12 @@ NP_BOOL_TYPES = (np.bool_, np.bool8)
NP_FLOAT_TYPES = (np.float_, np.float16, np.float32, np.float64)
NP_INT_TYPES = (np.int_, np.int8, np.int16, np.int32, np.int64, np.uint, np.uint8, np.uint16, np.uint32, np.uint64)

def to_dict(obj):
    return json.loads(json.dumps(obj, default=lambda o: o.__dict__, ensure_ascii=False))

def json_dump_nested_obj(obj):
    return json.dumps(obj, default=lambda o: o.__dict__, ensure_ascii=False)

# https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable
class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):