Unverified Commit 375290ba authored by dmMaze's avatar dmMaze Committed by GitHub
Browse files

Merge pull request #10 from dmMaze/lama

Add lama inpainter
parents 1ff39bca 8f9a03a9
Loading
Loading
Loading
Loading
+116 −13
Original line number Diff line number Diff line
@@ -2,17 +2,16 @@ import numpy as np
import cv2
from typing import Dict, List



from utils.registry import Registry
from utils.textblock_mask import canny_flood, connected_canny_flood, extract_ballon_mask
from utils.textblock_mask import extract_ballon_mask
from utils.imgproc_utils import enlarge_window

from ..moduleparamparser import ModuleParamParser, DEFAULT_DEVICE
from ..textdetector import TextBlock

# Registry named 'inpainters'; its register_module method is re-exported as
# `register_inpainter` and used below as a class decorator, e.g.
# @register_inpainter('lama_mpe').
INPAINTERS = Registry('inpainters')
register_inpainter = INPAINTERS.register_module

from ..moduleparamparser import ModuleParamParser, DEFAULT_DEVICE
from ..textdetector import TextBlock

class InpainterBase(ModuleParamParser):

@@ -93,17 +92,10 @@ class PatchmatchInpainter(InpainterBase):

import torch
from utils.imgproc_utils import resize_keepasp
from .aot import AOTGenerator
from .aot import AOTGenerator, load_aot_model
# Module-level slot for an AOTGenerator instance; starts out empty.
# NOTE(review): presumably filled lazily by AOTInpainter so the weights are
# loaded once and shared - confirm against the class body.
AOTMODEL: AOTGenerator = None
# Checkpoint location for the AOT model, given relative to the working directory.
AOTMODEL_PATH = 'data/models/aot_inpainter.ckpt'

def load_aot_model(model_path, device) -> AOTGenerator:
    """Create an AOTGenerator, load its weights and place it on ``device``.

    Args:
        model_path: path of the checkpoint file handed to ``torch.load``.
        device: target device passed to ``Module.to``.

    Returns:
        The generator in eval mode, moved to ``device``.
    """
    net = AOTGenerator(in_ch=4, out_ch=3, ch=32, alpha=0.0)
    # Deserialize on CPU first; the module is moved to the target device below.
    checkpoint = torch.load(model_path, map_location = 'cpu')
    # A checkpoint is either the bare state dict or wraps it under 'model'.
    if 'model' in checkpoint:
        weights = checkpoint['model']
    else:
        weights = checkpoint
    net.load_state_dict(weights)
    net.eval()
    net.to(device)
    return net


@register_inpainter('aot')
class AOTInpainter(InpainterBase):
@@ -205,3 +197,114 @@ class AOTInpainter(InpainterBase):

        elif param_key == 'inpaint_size':
            self.inpaint_size = int(self.setup_params['inpaint_size']['select'])


from .lama import LamaFourier, load_lama_mpe

# Shared LamaFourier instance; LamaInpainterMPE.setup_inpainter loads it on
# first use and reuses it afterwards.
LAMA_MPE: LamaFourier = None
@register_inpainter('lama_mpe')
class LamaInpainterMPE(InpainterBase):
    """Inpainter that runs the LaMa model with masked position encoding.

    The network is loaded once and kept in the module-level ``LAMA_MPE``
    global, so every instance reuses the same weights.
    """

    # Selectable options; the 'select' entry holds the currently chosen value.
    setup_params = {
        'inpaint_size': {
            'type': 'selector',
            'options': [
                1024,
                2048
            ],
            'select': 2048
        },
        'device': {
            'type': 'selector',
            'options': [
                'cpu',
                'cuda'
            ],
            'select': DEFAULT_DEVICE
        }
    }

    device = DEFAULT_DEVICE
    inpaint_size = 2048

    def setup_inpainter(self):
        """Load (or reuse) the shared model and read the current parameters."""
        global LAMA_MPE

        self.device = self.setup_params['device']['select']
        if LAMA_MPE is None:
            self.model = LAMA_MPE = load_lama_mpe(r'data/models/lama_mpe.ckpt', self.device)
        else:
            self.model = LAMA_MPE
            self.model.to(self.device)
        # Block-wise inpainting is enabled only when running on 'cuda'.
        self.inpaint_by_block = self.device == 'cuda'
        self.inpaint_size = int(self.setup_params['inpaint_size']['select'])

    def inpaint_preprocess(self, img: np.ndarray, mask: np.ndarray) -> tuple:
        """Turn an image/mask pair into the tensors consumed by the model.

        Args:
            img: H x W x C image array (it is permuted to C x H x W below).
            mask: H x W mask array; values >= 127 mark pixels to inpaint.

        Returns:
            An 8-tuple ``(img_torch, mask_torch, rel_pos, direct,
            img_original, mask_original, pad_bottom, pad_right)`` where
            ``img_original`` is an untouched copy of ``img``,
            ``mask_original`` is a 0/1 copy of ``mask`` with a trailing
            channel axis, and ``pad_bottom`` / ``pad_right`` are the numbers
            of rows / columns added to make the network input square.
        """

        img_original = np.copy(img)
        mask_original = np.copy(mask)
        mask_original[mask_original < 127] = 0
        mask_original[mask_original >= 127] = 1
        mask_original = mask_original[:, :, None]

        new_shape = self.inpaint_size if max(img.shape[0: 2]) > self.inpaint_size else None
        # high resolution input could produce cloudy artifacts
        img = resize_keepasp(img, new_shape, stride=64)
        mask = resize_keepasp(mask, new_shape, stride=64)

        # Pad bottom/right with reflected content so the input becomes square.
        im_h, im_w = img.shape[:2]
        longer = max(im_h, im_w)
        pad_bottom = longer - im_h if im_h < longer else 0
        pad_right = longer - im_w if im_w < longer else 0
        mask = cv2.copyMakeBorder(mask, 0, pad_bottom, 0, pad_right, cv2.BORDER_REFLECT)
        img = cv2.copyMakeBorder(img, 0, pad_bottom, 0, pad_right, cv2.BORDER_REFLECT)

        img_torch = torch.from_numpy(img).permute(2, 0, 1).unsqueeze_(0).float() / 255.0
        mask_torch = torch.from_numpy(mask).unsqueeze_(0).unsqueeze_(0).float() / 255.0
        # Re-binarize: resizing may have produced intermediate mask values.
        mask_torch[mask_torch < 0.5] = 0
        mask_torch[mask_torch >= 0.5] = 1
        rel_pos, _, direct = self.model.load_masked_position_encoding(mask_torch[0][0].numpy())
        rel_pos = torch.LongTensor(rel_pos).unsqueeze_(0)
        direct = torch.LongTensor(direct).unsqueeze_(0)

        if self.device == 'cuda':
            img_torch = img_torch.cuda()
            mask_torch = mask_torch.cuda()
            rel_pos = rel_pos.cuda()
            direct = direct.cuda()
        # Zero out the pixels that are going to be inpainted.
        img_torch *= (1 - mask_torch)
        return img_torch, mask_torch, rel_pos, direct, img_original, mask_original, pad_bottom, pad_right

    @torch.no_grad()
    def _inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:
        """Inpaint the masked region of ``img`` and return the composited image.

        Args:
            img: H x W x C image array.
            mask: H x W mask array; values >= 127 mark pixels to inpaint.
            textblock_list: accepted for interface compatibility; not used here.

        Returns:
            An array of the same height and width as ``img`` in which masked
            pixels come from the model output and all others from ``img``.
        """

        im_h, im_w = img.shape[:2]
        img_torch, mask_torch, rel_pos, direct, img_original, mask_original, pad_bottom, pad_right = self.inpaint_preprocess(img, mask)
        img_inpainted_torch = self.model(img_torch, mask_torch, rel_pos, direct)

        img_inpainted = (img_inpainted_torch.cpu().squeeze_(0).permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        # Drop the padding added in inpaint_preprocess.
        if pad_bottom > 0:
            img_inpainted = img_inpainted[:-pad_bottom]
        if pad_right > 0:
            img_inpainted = img_inpainted[:, :-pad_right]
        # Undo the optional downscale so the result matches the input size.
        new_shape = img_inpainted.shape[:2]
        if new_shape[0] != im_h or new_shape[1] != im_w:
            img_inpainted = cv2.resize(img_inpainted, (im_w, im_h), interpolation = cv2.INTER_LINEAR)
        # mask_original is 0/1: take model output inside the mask, original elsewhere.
        img_inpainted = img_inpainted * mask_original + img_original * (1 - mask_original)

        return img_inpainted

    def updateParam(self, param_key: str, param_content):
        """Apply a parameter change to the running inpainter.

        Args:
            param_key: name of the changed parameter ('device' or 'inpaint_size').
            param_content: new value; forwarded to the base class.
        """
        super().updateParam(param_key, param_content)

        if param_key == 'device':
            param_device = self.setup_params['device']['select']
            self.model.to(param_device)
            self.device = param_device
            # Keep the block-wise flag consistent with setup_inpainter:
            # enabled on 'cuda', disabled otherwise.
            self.inpaint_by_block = param_device == 'cuda'

        elif param_key == 'inpaint_size':
            self.inpaint_size = int(self.setup_params['inpaint_size']['select'])
 No newline at end of file
+6 −112
Original line number Diff line number Diff line
@@ -251,115 +251,9 @@ class AOTGenerator(nn.Module) :
		else :
			return torch.clip(x, -1, 1)

# class AOTInpainterTorch:
# 	def __init__(self, model_path: str, device: str = 'cpu'):
# 		self.device = device
# 		self.net = AOTGenerator(in_ch=4, out_ch=3, ch=32, alpha=0.0)
# 		sd = torch.load(model_path, map_location = 'cpu')
# 		self.net.load_state_dict(sd['model'] if 'model' in sd else sd)
# 		self.net.eval().to(device)
    
# 	def inpaint_preprocess(self, img: np.ndarray, mask: np.ndarray, inpaint_size: int = 1024) -> np.ndarray:

# 		pad_size = 4
# 		img_original = np.copy(img)
# 		mask_original = np.copy(mask)
# 		mask_original[mask_original < 127] = 0
# 		mask_original[mask_original >= 127] = 1
# 		mask_original = mask_original[:, :, None]
# 		h, w, c = img.shape
# 		new_shape = inpaint_size if max(img.shape[0: 2]) > inpaint_size else None

# 		img = resize_keepasp(img, new_shape, stride=None)
# 		mask = resize_keepasp(mask, new_shape, stride=None)

# 		im_h, im_w = img.shape[:2]
# 		pad_bottom = 128 - im_h if im_h < 128 else 0
# 		pad_right = 128 - im_w if im_w < 128 else 0
# 		mask = cv2.copyMakeBorder(mask, 0, pad_bottom, 0, pad_right, cv2.BORDER_REFLECT)
# 		img = cv2.copyMakeBorder(img, 0, pad_bottom, 0, pad_right, cv2.BORDER_REFLECT)

# 		img_torch = torch.from_numpy(img).permute(2, 0, 1).unsqueeze_(0).float() / 127.5 - 1.0
# 		mask_torch = torch.from_numpy(mask).unsqueeze_(0).unsqueeze_(0).float() / 255.0
# 		mask_torch[mask_torch < 0.5] = 0
# 		mask_torch[mask_torch >= 0.5] = 1

# 		if self.device == 'cuda':
# 			img_torch = img_torch.cuda()
# 			mask_torch = mask_torch.cuda()
# 		img_torch *= (1 - mask_torch)
# 		return img_torch, mask_torch, img_original, mask_original, pad_bottom, pad_right

# 	@torch.no_grad()
# 	def __call__(self, img: np.ndarray, mask: np.ndarray, inpaint_size: int = 1024) -> np.ndarray:
# 		im_h, im_w = img.shape[:2]
# 		img_torch, mask_torch, img_original, mask_original, pad_bottom, pad_right = self.inpaint_preprocess(img, mask, inpaint_size)
# 		img_inpainted_torch = self.net(img_torch, mask_torch)
# 		img_inpainted = ((img_inpainted_torch.cpu().squeeze_(0).permute(1, 2, 0).numpy() + 1.0) * 127.5).astype(np.uint8)

# 		if pad_bottom > 0:
# 			img_inpainted = img_inpainted[:-pad_bottom]
# 		if pad_right > 0:
# 			img_inpainted = img_inpainted[:, :-pad_right]

# 		new_shape = img_inpainted.shape[:2]
# 		if new_shape[0] != im_h or new_shape[1] != im_w :
# 			img_inpainted = cv2.resize(img_inpainted, (im_w, im_h), interpolation = cv2.INTER_LINEAR)
		
# 		img_inpainted = img_inpainted * mask_original + img_original * (1 - mask_original)

# 		return img_inpainted


# def dispatch(use_inpainting: bool, use_poisson_blending: bool, cuda: bool, img: np.ndarray, mask: np.ndarray, inpainting_size: int = 1024, model_name: str = 'default', verbose: bool = False) -> np.ndarray :
# 	img_original = np.copy(img)
# 	mask_original = np.copy(mask)
# 	mask_original[mask_original < 127] = 0
# 	mask_original[mask_original >= 127] = 1
# 	mask_original = mask_original[:, :, None]
# 	if not use_inpainting :
# 		img = np.copy(img)
# 		img[mask > 0] = np.array([255, 255, 255], np.uint8)
# 		if verbose :
# 			return img, img
# 		else :
# 			return img
# 	height, width, c = img.shape
# 	if max(img.shape[0: 2]) > inpainting_size :
# 		img = resize_keep_aspect(img, inpainting_size)
# 		mask = resize_keep_aspect(mask, inpainting_size)
# 	pad_size = 4
# 	h, w, c = img.shape
# 	if h % pad_size != 0 :
# 		new_h = (pad_size - (h % pad_size)) + h
# 	else :
# 		new_h = h
# 	if w % pad_size != 0 :
# 		new_w = (pad_size - (w % pad_size)) + w
# 	else :
# 		new_w = w
# 	if new_h != h or new_w != w :
# 		img = cv2.resize(img, (new_w, new_h), interpolation = cv2.INTER_LINEAR)
# 		mask = cv2.resize(mask, (new_w, new_h), interpolation = cv2.INTER_LINEAR)
# 	if verbose :
# 		print(f'Inpainting resolution: {new_w}x{new_h}')
# 	img_torch = torch.from_numpy(img).permute(2, 0, 1).unsqueeze_(0).float() / 127.5 - 1.0
# 	mask_torch = torch.from_numpy(mask).unsqueeze_(0).unsqueeze_(0).float() / 255.0
# 	mask_torch[mask_torch < 0.5] = 0
# 	mask_torch[mask_torch >= 0.5] = 1
# 	if cuda :
# 		img_torch = img_torch.cuda()
# 		mask_torch = mask_torch.cuda()
# 	with torch.no_grad() :
# 		img_torch *= (1 - mask_torch)
# 		img_inpainted_torch = DEFAULT_MODEL(img_torch, mask_torch)
# 	img_inpainted = ((img_inpainted_torch.cpu().squeeze_(0).permute(1, 2, 0).numpy() + 1.0) * 127.5).astype(np.uint8)
# 	if new_h != height or new_w != width :
# 		img_inpainted = cv2.resize(img_inpainted, (width, height), interpolation = cv2.INTER_LINEAR)
# 	if use_poisson_blending :
# 		raise NotImplemented
# 	else :
# 		ans = img_inpainted * mask_original + img_original * (1 - mask_original)
# 	if verbose :
# 		return ans, (img_torch.cpu() * 127.5 + 127.5).squeeze_(0).permute(1, 2, 0).numpy()
# 	return ans
 No newline at end of file
def load_aot_model(model_path, device) -> AOTGenerator:
    """Create an AOTGenerator, load its weights and place it on ``device``.

    Args:
        model_path: path of the checkpoint file handed to ``torch.load``.
        device: target device passed to ``Module.to``.

    Returns:
        The generator in eval mode, moved to ``device``.
    """
    net = AOTGenerator(in_ch=4, out_ch=3, ch=32, alpha=0.0)
    # Deserialize on CPU first; the module is moved to the target device below.
    checkpoint = torch.load(model_path, map_location = 'cpu')
    # A checkpoint is either the bare state dict or wraps it under 'model'.
    if 'model' in checkpoint:
        weights = checkpoint['model']
    else:
        weights = checkpoint
    net.load_state_dict(weights)
    net.eval()
    net.to(device)
    return net
 No newline at end of file
+296 −0
Original line number Diff line number Diff line
# Fast Fourier Convolution NeurIPS 2020
# original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
# paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf



import torch
import torch.nn as nn
import torch.nn.functional as F


class FFCSE_block(nn.Module):
    """Channel gate for a (local, global) feature pair.

    Both parts are pooled to 1x1, squeezed through a shared 1x1 convolution
    and then expanded by two separate 1x1 convolutions whose sigmoid outputs
    scale the local and the global input respectively.
    """

    def __init__(self, channels, ratio_g):
        """
        Args:
            channels: total number of channels (local + global).
            ratio_g: fraction of ``channels`` that belongs to the global part.
        """
        super().__init__()
        global_ch = int(channels * ratio_g)
        local_ch = channels - global_ch
        reduction = 16
        squeezed_ch = channels // reduction

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.conv1 = nn.Conv2d(channels, squeezed_ch, kernel_size=1, bias=True)
        self.relu1 = nn.ReLU(inplace=True)
        # A branch with zero channels gets no expansion convolution.
        if local_ch == 0:
            self.conv_a2l = None
        else:
            self.conv_a2l = nn.Conv2d(squeezed_ch, local_ch, kernel_size=1, bias=True)
        if global_ch == 0:
            self.conv_a2g = None
        else:
            self.conv_a2g = nn.Conv2d(squeezed_ch, global_ch, kernel_size=1, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Gate ``x`` (a tensor or a ``(local, global)`` tuple).

        Returns:
            ``(x_l, x_g)``; a part whose branch is absent is the int ``0``.
        """
        if type(x) is not tuple:
            x = (x, 0)
        id_l, id_g = x

        # An int in the global slot means "no global part".
        if type(id_g) is int:
            pooled_in = id_l
        else:
            pooled_in = torch.cat([id_l, id_g], dim=1)
        squeezed = self.relu1(self.conv1(self.avgpool(pooled_in)))

        if self.conv_a2l is None:
            x_l = 0
        else:
            x_l = id_l * self.sigmoid(self.conv_a2l(squeezed))
        if self.conv_a2g is None:
            x_g = 0
        else:
            x_g = id_g * self.sigmoid(self.conv_a2g(squeezed))
        return x_l, x_g


class FourierUnit(nn.Module):
    """1x1 convolution + BatchNorm + ReLU applied in the frequency domain.

    The input is transformed with a real FFT, real and imaginary parts are
    stacked along the channel axis, processed, and transformed back with the
    inverse real FFT to the spatial size of the input.
    """

    def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
                 spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
        """
        Args:
            in_channels: channels of the spatial input.
            out_channels: channels of the spatial output.
            groups: ``groups`` of the spectral 1x1 convolution.
            spatial_scale_factor: if not None, the input is rescaled by this
                factor before the FFT and the output is resized back.
            spatial_scale_mode: interpolation mode for that rescaling.
            spectral_pos_encoding: prepend two coordinate channels (vertical
                and horizontal, each in [0, 1]) to the spectrum.
            use_se: apply ``self.se`` to the spectrum. No such layer is built
                here, so this only works if a subclass provides one.
            se_kwargs: unused (kept for signature compatibility).
            ffc3d: take the FFT over the last three dims instead of two.
            fft_norm: ``norm`` argument of ``torch.fft.rfftn`` / ``irfftn``.
        """
        # bn_layer not used
        super(FourierUnit, self).__init__()
        self.groups = groups

        self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
                                          out_channels=out_channels * 2,
                                          kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
        self.bn = torch.nn.BatchNorm2d(out_channels * 2)
        self.relu = torch.nn.ReLU(inplace=True)

        # squeeze and excitation block
        self.use_se = use_se
        # if use_se:
        #     if se_kwargs is None:
        #         se_kwargs = {}
        #     self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)

        self.spatial_scale_factor = spatial_scale_factor
        self.spatial_scale_mode = spatial_scale_mode
        self.spectral_pos_encoding = spectral_pos_encoding
        self.ffc3d = ffc3d
        self.fft_norm = fft_norm

    def forward(self, x):
        """Apply the spectral transform to ``x`` of shape (batch, c, h, w).

        Raises:
            NotImplementedError: if ``use_se`` is set but no ``se`` layer exists.
        """
        batch = x.shape[0]

        if self.spatial_scale_factor is not None:
            orig_size = x.shape[-2:]
            x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False)

        fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
        # The FFT is computed in float32; half-precision input is upcast and
        # the result is not cast back.
        if x.dtype == torch.float16:
            x = x.type(torch.float32)
        ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)  # (batch, c, h, w/2+1, 2)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
        ffted = ffted.view((batch, -1,) + ffted.size()[3:])

        if self.spectral_pos_encoding:
            height, width = ffted.shape[-2:]
            coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
            coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
            ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)

        if self.use_se:
            # __init__ does not build an SE layer; fail with a clear message
            # unless a subclass attached one.
            se = getattr(self, 'se', None)
            if se is None:
                raise NotImplementedError(
                    'FourierUnit was created with use_se=True but no squeeze-and-excitation layer is available')
            ffted = se(ffted)

        ffted = self.conv_layer(ffted)  # (batch, c*2, h, w/2+1)
        ffted = self.relu(self.bn(ffted))

        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
            0, 1, 3, 4, 2).contiguous()  # (batch, c, h, w/2+1, 2)
        if ffted.dtype == torch.float16:
            ffted = ffted.type(torch.float32)
        ffted = torch.complex(ffted[..., 0], ffted[..., 1])

        # Ask for the exact spatial size of x so odd widths round-trip.
        ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)

        if self.spatial_scale_factor is not None:
            output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)

        return output


class SpectralTransform(nn.Module):
    """Global branch of an FFC layer: 1x1 conv, Fourier units, 1x1 conv.

    With ``enable_lfu`` a second FourierUnit runs on the four spatial
    quadrants of the first quarter of the channels (stacked along the channel
    axis) and its output is tiled back to full size.
    """

    def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs):
        """
        Args:
            in_channels: input channels.
            out_channels: output channels; the inner width is ``out_channels // 2``.
            stride: 2 inserts a 2x2 average pooling in front, otherwise no-op.
            groups: ``groups`` for the convolutions and Fourier units.
            enable_lfu: add the quadrant-wise (local) Fourier unit.
            **fu_kwargs: extra keyword arguments for the main FourierUnit.
        """
        # bn_layer not used
        super(SpectralTransform, self).__init__()
        self.enable_lfu = enable_lfu
        if stride == 2:
            self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        else:
            self.downsample = nn.Identity()

        self.stride = stride
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels //
                      2, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(out_channels // 2),
            nn.ReLU(inplace=True)
        )
        self.fu = FourierUnit(
            out_channels // 2, out_channels // 2, groups, **fu_kwargs)
        if self.enable_lfu:
            self.lfu = FourierUnit(
                out_channels // 2, out_channels // 2, groups)
        self.conv2 = torch.nn.Conv2d(
            out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)

    def forward(self, x):
        """Transform ``x`` of shape (batch, in_channels, h, w)."""

        x = self.downsample(x)
        x = self.conv1(x)
        output = self.fu(x)

        if self.enable_lfu:
            _, c, h, w = x.shape
            split_no = 2
            # Split each axis by its own half so non-square inputs also yield
            # exactly split_no * split_no quadrants (identical for h == w).
            split_h = h // split_no
            split_w = w // split_no
            xs = torch.cat(torch.split(
                x[:, :c // 4], split_h, dim=-2), dim=1).contiguous()
            xs = torch.cat(torch.split(xs, split_w, dim=-1),
                           dim=1).contiguous()
            xs = self.lfu(xs)
            # Tile the quadrant-sized result back to the full spatial size.
            xs = xs.repeat(1, 1, split_no, split_no).contiguous()
        else:
            xs = 0

        output = self.conv2(x + output + xs)

        return output


class FFC(nn.Module):
    """Fast Fourier Convolution over a (local, global) channel split.

    Four paths connect the two input parts to the two output parts:
    local->local, local->global and global->local are ordinary convolutions,
    global->global is a SpectralTransform. A path whose input or output part
    has zero channels is replaced by ``nn.Identity``.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 ratio_gin, ratio_gout, stride=1, padding=0,
                 dilation=1, groups=1, bias=False, enable_lfu=True,
                 padding_type='reflect', gated=False, **spectral_kwargs):
        """
        Args:
            in_channels / out_channels: total channels (local + global).
            kernel_size, stride, padding, dilation, groups, bias: forwarded
                positionally to the ``nn.Conv2d`` paths.
            ratio_gin / ratio_gout: fraction of input / output channels that
                form the global part.
            enable_lfu: forwarded to SpectralTransform.
            padding_type: ``padding_mode`` of the convolutions.
            gated: learn sigmoid gates for the two cross paths.
            **spectral_kwargs: extra keyword arguments for SpectralTransform.
        """
        super().__init__()

        assert stride in (1, 2), "Stride should be 1 or 2."
        self.stride = stride

        in_cg = int(in_channels * ratio_gin)
        in_cl = in_channels - in_cg
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg

        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        self.global_in_num = in_cg

        # Positional arguments shared by the three convolution paths.
        conv_args = (kernel_size, stride, padding, dilation, groups, bias)

        if in_cl == 0 or out_cl == 0:
            l2l_cls = nn.Identity
        else:
            l2l_cls = nn.Conv2d
        self.convl2l = l2l_cls(in_cl, out_cl, *conv_args, padding_mode=padding_type)

        if in_cl == 0 or out_cg == 0:
            l2g_cls = nn.Identity
        else:
            l2g_cls = nn.Conv2d
        self.convl2g = l2g_cls(in_cl, out_cg, *conv_args, padding_mode=padding_type)

        if in_cg == 0 or out_cl == 0:
            g2l_cls = nn.Identity
        else:
            g2l_cls = nn.Conv2d
        self.convg2l = g2l_cls(in_cg, out_cl, *conv_args, padding_mode=padding_type)

        if in_cg == 0 or out_cg == 0:
            g2g_cls = nn.Identity
        else:
            g2g_cls = SpectralTransform
        spectral_groups = 1 if groups == 1 else groups // 2
        self.convg2g = g2g_cls(
            in_cg, out_cg, stride, spectral_groups, enable_lfu, **spectral_kwargs)

        self.gated = gated
        if in_cg == 0 or out_cl == 0 or not self.gated:
            gate_cls = nn.Identity
        else:
            gate_cls = nn.Conv2d
        self.gate = gate_cls(in_channels, 2, 1)

    def forward(self, x):
        """Apply the layer to a tensor or a ``(local, global)`` tuple.

        Returns:
            ``(out_xl, out_xg)``; a part that is not produced is the int ``0``.
        """
        if type(x) is tuple:
            x_l, x_g = x
        else:
            x_l, x_g = x, 0

        if self.gated:
            # The gate sees all available input channels.
            parts = [x_l]
            if torch.is_tensor(x_g):
                parts.append(x_g)
            gates = torch.sigmoid(self.gate(torch.cat(parts, dim=1)))
            g2l_gate, l2g_gate = gates.chunk(2, dim=1)
        else:
            g2l_gate = l2g_gate = 1

        out_xl = 0
        out_xg = 0
        if self.ratio_gout != 1:
            out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
        if self.ratio_gout != 0:
            out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)

        return out_xl, out_xg


class FFC_BN_ACT(nn.Module):
    """An FFC layer followed by per-branch normalization and activation.

    A branch that carries no output channels (``ratio_gout`` of 1 for the
    local branch, 0 for the global branch) gets ``nn.Identity`` in place of
    both its normalization and its activation.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, ratio_gin, ratio_gout,
                 stride=1, padding=0, dilation=1, groups=1, bias=False,
                 norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity,
                 padding_type='reflect',
                 enable_lfu=True, **kwargs):
        """
        Args:
            norm_layer: class called with the channel count of a branch.
            activation_layer: class called with ``inplace=True``.
            Remaining arguments are forwarded to ``FFC``.
        """
        super().__init__()
        self.ffc = FFC(in_channels, out_channels, kernel_size,
                       ratio_gin, ratio_gout, stride, padding, dilation,
                       groups, bias, enable_lfu, padding_type=padding_type, **kwargs)

        local_is_empty = ratio_gout == 1
        global_is_empty = ratio_gout == 0
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg

        local_norm_cls = nn.Identity if local_is_empty else norm_layer
        global_norm_cls = nn.Identity if global_is_empty else norm_layer
        self.bn_l = local_norm_cls(out_cl)
        self.bn_g = global_norm_cls(out_cg)

        local_act_cls = nn.Identity if local_is_empty else activation_layer
        global_act_cls = nn.Identity if global_is_empty else activation_layer
        self.act_l = local_act_cls(inplace=True)
        self.act_g = global_act_cls(inplace=True)

    def forward(self, x):
        """Run FFC, then normalize and activate each branch separately."""
        local_out, global_out = self.ffc(x)
        local_out = self.bn_l(local_out)
        local_out = self.act_l(local_out)
        global_out = self.bn_g(global_out)
        global_out = self.act_g(global_out)
        return local_out, global_out


class FFCResnetBlock(nn.Module):
    """Residual block made of two 3x3 ``FFC_BN_ACT`` layers.

    With ``inline=True`` the block takes and returns a single tensor whose
    trailing ``global_in_num`` channels are the global part; otherwise it
    works on a ``(local, global)`` tuple (a bare tensor is treated as
    local-only).
    """

    def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1,
                 spatial_transform_kwargs=None, inline=False, **conv_kwargs):
        """
        Args:
            dim: channel count of input and output.
            padding_type, norm_layer, activation_layer: forwarded to both layers.
            dilation: used as both dilation and padding of the 3x3 kernels.
            spatial_transform_kwargs: unused here.
            inline: see class description.
            **conv_kwargs: extra keyword arguments for ``FFC_BN_ACT``.
        """
        super().__init__()

        def build_layer():
            # Both layers are configured identically.
            return FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
                              norm_layer=norm_layer,
                              activation_layer=activation_layer,
                              padding_type=padding_type,
                              **conv_kwargs)

        self.conv1 = build_layer()
        self.conv2 = build_layer()
        # if spatial_transform_kwargs is not None:
        #     self.conv1 = LearnableSpatialTransformWrapper(self.conv1, **spatial_transform_kwargs)
        #     self.conv2 = LearnableSpatialTransformWrapper(self.conv2, **spatial_transform_kwargs)
        self.inline = inline

    def forward(self, x):
        """Return ``x`` plus the output of the two stacked layers."""
        if self.inline:
            n_global = self.conv1.ffc.global_in_num
            skip_l = x[:, :-n_global]
            skip_g = x[:, -n_global:]
        elif type(x) is tuple:
            skip_l, skip_g = x
        else:
            skip_l, skip_g = x, 0

        res_l, res_g = self.conv1((skip_l, skip_g))
        res_l, res_g = self.conv2((res_l, res_g))

        out = (skip_l + res_l, skip_g + res_g)
        if self.inline:
            out = torch.cat(out, dim=1)
        return out
 No newline at end of file
+419 −0

File added.

Preview size limit exceeded, changes collapsed.

+6 −2
Original line number Diff line number Diff line
import time
from typing import Union
import numpy as np
import traceback


from PyQt5.QtCore import QThread, pyqtSignal, QObject, QLocale
from PyQt5.QtWidgets import QMessageBox
@@ -48,7 +50,8 @@ class ModuleThread(QThread):
        except Exception as e:
            self.module = old_module
            msg = self.tr('Failed to set ') + module_name
            self.exception_occurred.emit(msg, str(e))
            
            self.exception_occurred.emit(msg, str(e) + '\n' + f'exc: {traceback.format_exc()}')
        self.finish_set_module.emit()

    def pipeline_finished(self):
@@ -103,7 +106,8 @@ class InpaintThread(ModuleThread):
            }
            self.finish_inpaint.emit(inpaint_dict)
        except Exception as e:
            self.exception_occurred.emit(self.tr('Inpainting Failed.'), repr(e))
            # self.exception_occurred.emit(self.tr('Inpainting Failed.'), repr(e))
            self.exception_occurred.emit(self.tr('Inpainting Failed.'), str(e) + '\n' + f'exc: {traceback.format_exc()}')


class TextDetectThread(ModuleThread):
Loading