Commit 28e6eb5c authored by dmMaze's avatar dmMaze
Browse files

fix pe length for mit models

parent 95953b22
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -76,7 +76,7 @@ class CustomTransformerEncoderLayer(nn.Module):
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.pe = PositionalEncoding(d_model, max_len = 2048)
        self.pe = PositionalEncoding(d_model, max_len = 3072)

        self.activation = F.gelu

+1 −67
Original line number Diff line number Diff line
@@ -24,7 +24,7 @@ from utils.message import create_error_dialog, create_info_dialog
from .custom_widget import ImgtransProgressMessageBox, ParamComboBox
from .configpanel import ConfigPanel
from utils.proj_imgtrans import ProjImgTrans
from utils.config import pcfg
from utils.config import pcfg, merge_config_module_params
cfg_module = pcfg.module


@@ -512,72 +512,6 @@ class ImgtransThread(QThread):
        return ref_counter - 1


def merge_config_module_params(config_params: Dict, module_keys: List, get_module: Callable) -> Dict:
    for module_key in module_keys:
        module_params = get_module(module_key).params
        if module_key not in config_params or config_params[module_key] is None:
            config_params[module_key] = module_params
        else:
            cfg_param = config_params[module_key]
            cfg_key_set = set(cfg_param.keys())
            module_key_set = set(module_params.keys())
            for ck in cfg_key_set:
                if ck not in module_key_set:
                    LOGGER.warning(f'Found invalid {module_key} config: {ck}')
                    cfg_param.pop(ck)

            for mk in module_key_set:
                if mk not in cfg_key_set:
                    # LOGGER.info(f'Found new {module_key} config: {mk}')
                    cfg_param[mk] = module_params[mk]
                else:
                    mparam = module_params[mk]
                    cparam = cfg_param[mk]
                    if isinstance(mparam, dict):
                        tgt_type = type(mparam['value'])
                        if isinstance(cparam, dict):
                            if 'value' in cparam:
                                v = cparam['value']
                            elif isinstance(mparam['value'], dict):
                                for k in mparam['value']:
                                    if k in cparam:
                                        mparam['value'][k] = cparam[k]
                                v = mparam['value']
                            else:
                                v = mparam['value']
                        else:
                            v = cparam
                        valid = True
                        if tgt_type != type(v):
                            try:
                                v = tgt_type(v)
                            except:
                                valid = False
                                LOGGER.warning(f'Invalid param value {v} for defined dtype: {tgt_type}, it will be set to default value: {mparam}')
                        if valid:
                            mparam['value'] = v
                        cfg_param[mk] = mparam
                    else:
                        if type(cparam) != type(mparam):
                            if not isinstance(mparam, dict) and isinstance(cparam, dict):
                                cparam = cparam['value']
                            try:
                                cfg_param[mk] = type(mparam)(cparam)
                            except ValueError:
                                LOGGER.warning(f'Invalid param value {cparam} for defined dtype: {type(mparam)}, it will be set to default value: {mparam}')
                                cfg_param[mk] = mparam
            
            cfg_key_list = list(cfg_param.keys())
            module_key_list = list(module_params.keys())
            if cfg_key_list != module_key_list:
                LOGGER.info(f'Reorder param dict in config')
                new_params = {key: cfg_param[key] for key in module_key_list}
                cfg_param.clear()
                cfg_param.update(new_params)

    return config_params


def unload_modules(self, module_names):
    model_deleted = False
    for module in module_names:
+68 −1
Original line number Diff line number Diff line
import json, os, traceback
import os.path as osp
import copy
from typing import Callable

from . import shared
from .fontformat import FontFormat
@@ -288,3 +289,69 @@ def save_text_styles(raise_exception = False):
    os.replace(tmp_save_tgt, pcfg.text_styles_path)
    LOGGER.info('Text style saved')
    return True


def merge_config_module_params(config_params: Dict, module_keys: List, get_module: Callable) -> Dict:
    for module_key in module_keys:
        module_params = get_module(module_key).params
        if module_key not in config_params or config_params[module_key] is None:
            config_params[module_key] = module_params
        else:
            cfg_param = config_params[module_key]
            cfg_key_set = set(cfg_param.keys())
            module_key_set = set(module_params.keys())
            for ck in cfg_key_set:
                if ck not in module_key_set:
                    LOGGER.warning(f'Found invalid {module_key} config: {ck}')
                    cfg_param.pop(ck)

            for mk in module_key_set:
                if mk not in cfg_key_set:
                    # LOGGER.info(f'Found new {module_key} config: {mk}')
                    cfg_param[mk] = module_params[mk]
                else:
                    mparam = module_params[mk]
                    cparam = cfg_param[mk]
                    if isinstance(mparam, dict):
                        tgt_type = type(mparam['value'])
                        if isinstance(cparam, dict):
                            if 'value' in cparam:
                                v = cparam['value']
                            elif isinstance(mparam['value'], dict):
                                for k in mparam['value']:
                                    if k in cparam:
                                        mparam['value'][k] = cparam[k]
                                v = mparam['value']
                            else:
                                v = mparam['value']
                        else:
                            v = cparam
                        valid = True
                        if tgt_type != type(v):
                            try:
                                v = tgt_type(v)
                            except:
                                valid = False
                                LOGGER.warning(f'Invalid param value {v} for defined dtype: {tgt_type}, it will be set to default value: {mparam}')
                        if valid:
                            mparam['value'] = v
                        cfg_param[mk] = mparam
                    else:
                        if type(cparam) != type(mparam):
                            if not isinstance(mparam, dict) and isinstance(cparam, dict):
                                cparam = cparam['value']
                            try:
                                cfg_param[mk] = type(mparam)(cparam)
                            except ValueError:
                                LOGGER.warning(f'Invalid param value {cparam} for defined dtype: {type(mparam)}, it will be set to default value: {mparam}')
                                cfg_param[mk] = mparam
            
            cfg_key_list = list(cfg_param.keys())
            module_key_list = list(module_params.keys())
            if cfg_key_list != module_key_list:
                LOGGER.info(f'Reorder param dict in config')
                new_params = {key: cfg_param[key] for key in module_key_list}
                cfg_param.clear()
                cfg_param.update(new_params)

    return config_params