Commit 046bc0e1 authored by dmMaze's avatar dmMaze
Browse files

Convert config param value to defined data type automatically

parent 3ba3d619
Loading
Loading
Loading
Loading
+27 −6
Original line number Diff line number Diff line
@@ -65,13 +65,34 @@ class BaseModule:
        assert cls._preprocess_hooks is not None
        register_hooks(cls._preprocess_hooks, callbacks)

    def updateParam(self, param_key: str, param_content):
        self_param_content = self.params[param_key]
        if isinstance(self_param_content, (str, float, int)):
            self.params[param_key] = param_content
    def get_param_value(self, param_key: str):
        assert self.params is not None and param_key in self.params
        p = self.params[param_key]
        if isinstance(p, dict):
            return p['value']
        return p
    
    def set_param_value(self, param_key: str, param_value, convert_dtype=True):
        assert self.params is not None and param_key in self.params
        p = self.params[param_key]
        if isinstance(p, dict):
            if convert_dtype:
                try:
                    param_value = type(p['value'])(param_value)
                except ValueError:
                    dtype = type(p['value'])
                    self.logger.warning(f'Invalid param value {param_value} for defined dtype: {dtype}')
            p['value'] = param_value
        else:
            param_dict = self.params[param_key]
            param_dict['value'] = param_content
            if convert_dtype:
                try:
                    param_value = type(p)(param_value)
                except ValueError:
                    self.logger.warning(f'Invalid param value {param_value} for defined dtype: {type(p)}')
            self.params[param_key] = param_value

    def updateParam(self, param_key: str, param_content):
        self.set_param_value(param_key, param_content)

    def is_cpu_intensive(self)->bool:
        if self.params is not None and 'device' in self.params:
+5 −5
Original line number Diff line number Diff line
@@ -109,23 +109,23 @@ class GPTTranslator(BaseTranslator):

    @property
    def temperature(self) -> float:
        return float(self.params['temperature'])
        return self.params['temperature']
    
    @property
    def max_tokens(self) -> int:
        return int(self.params['max tokens'])
        return self.params['max tokens']
    
    @property
    def top_p(self) -> int:
        return int(self.params['top p'])
        return self.params['top p']
    
    @property
    def retry_attempts(self) -> int:
        return int(self.params['retry attempts'])
        return self.params['retry attempts']
    
    @property
    def retry_timeout(self) -> int:
        return int(self.params['retry timeout'])
        return self.params['retry timeout']
    
    @property
    def chat_system_template(self) -> str:
+3 −3
Original line number Diff line number Diff line
@@ -239,15 +239,15 @@ class SakuraTranslator(BaseTranslator):

    @property
    def max_tokens(self) -> int:
        return int(self.params['max tokens'])
        return self.params['max tokens']

    @property
    def timeout(self) -> int:
        return int(self.params['timeout'])
        return self.params['timeout']

    @property
    def retry_attempts(self) -> int:
        return int(self.params['retry attempts'])
        return self.params['retry attempts']

    @property
    def api_base_raw(self) -> str:
+17 −0
Original line number Diff line number Diff line
@@ -123,6 +123,7 @@ class InpaintThread(ModuleThread):
            self.inpaint_failed.emit()
        self.inpainting = False


class TextDetectThread(ModuleThread):
    
    finish_detect_page = Signal(str)
@@ -531,6 +532,22 @@ def merge_config_module_params(config_params: Dict, module_keys: List, get_modul
                                continue
                            if k not in mparam:
                                cparam.pop(k)
                        if type(cparam['value']) != type(mparam['value']):
                            try:
                                cparam['value'] = type(mparam['value'])(cparam['value'])
                            except:
                                dtype = type(mparam['value'])
                                mv = mparam['value']
                                cv = cparam['value']
                                LOGGER.warning(f'Invalid param value {cv} for defined dtype: {dtype}, it will be set to default value: {mv}')
                                cparam['value'] = mv
                    else:
                        if type(cparam) != type(mparam):
                            try:
                                cfg_param[mk] = type(mparam)(cparam)
                            except ValueError:
                                LOGGER.warning(f'Invalid param value {cparam} for defined dtype: {type(mparam)}, it will be set to default value: {mparam}')
                                cfg_param[mk] = mparam
            
            cfg_key_list = list(cfg_param.keys())
            module_key_list = list(module_params.keys())