Commit 8f9a03a9 authored by dmMaze's avatar dmMaze
Browse files

rm lama(no mpe)

parent 91803427
Loading
Loading
Loading
Loading
+24 −36
Original line number Diff line number Diff line
@@ -2,17 +2,16 @@ import numpy as np
import cv2
from typing import Dict, List



from utils.registry import Registry
from utils.textblock_mask import canny_flood, connected_canny_flood, extract_ballon_mask
from utils.textblock_mask import extract_ballon_mask
from utils.imgproc_utils import enlarge_window

from ..moduleparamparser import ModuleParamParser, DEFAULT_DEVICE
from ..textdetector import TextBlock

INPAINTERS = Registry('inpainters')
register_inpainter = INPAINTERS.register_module

from ..moduleparamparser import ModuleParamParser, DEFAULT_DEVICE
from ..textdetector import TextBlock

class InpainterBase(ModuleParamParser):

@@ -93,17 +92,10 @@ class PatchmatchInpainter(InpainterBase):

import torch
from utils.imgproc_utils import resize_keepasp
from .aot import AOTGenerator
from .aot import AOTGenerator, load_aot_model
AOTMODEL: AOTGenerator = None
AOTMODEL_PATH = 'data/models/aot_inpainter.ckpt'

def load_aot_model(model_path, device) -> AOTGenerator:
    """Create an ``AOTGenerator`` and restore its weights from a checkpoint.

    The checkpoint at ``model_path`` is deserialized onto the CPU. When it
    contains a ``'model'`` entry, that entry is used as the state dict;
    otherwise the loaded object itself is used. The network is switched to
    evaluation mode and then moved to ``device`` before being returned.
    """
    net = AOTGenerator(in_ch=4, out_ch=3, ch=32, alpha=0.0)
    checkpoint = torch.load(model_path, map_location = 'cpu')
    if 'model' in checkpoint:
        weights = checkpoint['model']
    else:
        weights = checkpoint
    net.load_state_dict(weights)
    net.eval().to(device)
    return net


@register_inpainter('aot')
class AOTInpainter(InpainterBase):
@@ -207,20 +199,11 @@ class AOTInpainter(InpainterBase):
            self.inpaint_size = int(self.setup_params['inpaint_size']['select'])


from .lama import LamaFourier
LAMAMODEL: LamaFourier = None
LAMAMODEL_PATH = 'data/models/accum4_l1m10_448px_last.ckpt'

def load_lama_model(model_path, device) -> LamaFourier:
    """Create a discriminator-free ``LamaFourier`` and load generator weights.

    The checkpoint at ``model_path`` is deserialized onto the CPU. When it
    contains a ``'gen_state_dict'`` entry, that entry is loaded into the
    generator; otherwise the loaded object itself is treated as the
    generator state dict. ``eval()`` is then called on the wrapper and the
    object it returns is moved to ``device``.
    """
    model = LamaFourier(build_discriminator=False)
    checkpoint = torch.load(model_path, map_location = 'cpu')
    if 'gen_state_dict' in checkpoint:
        generator_weights = checkpoint['gen_state_dict']
    else:
        generator_weights = checkpoint
    model.generator.load_state_dict(generator_weights)
    model.eval().to(device)
    return model
from .lama import LamaFourier, load_lama_mpe


@register_inpainter('lama')
class LamaInpainter(InpainterBase):
LAMA_MPE: LamaFourier = None
@register_inpainter('lama_mpe')
class LamaInpainterMPE(InpainterBase):

    setup_params = {
        'inpaint_size': {
@@ -243,15 +226,15 @@ class LamaInpainter(InpainterBase):

    device = DEFAULT_DEVICE
    inpaint_size = 2048
    LAMAMODEL: LamaFourier = None

    def setup_inpainter(self):
        global LAMAMODEL
        global LAMA_MPE

        self.device = self.setup_params['device']['select']
        if LAMAMODEL is None:
            self.model = LAMAMODEL = load_lama_model(LAMAMODEL_PATH, self.device)
        if LAMA_MPE is None:
            self.model = LAMA_MPE = load_lama_mpe(r'data/models/lama_mpe.ckpt', self.device)
        else:
            self.model = LAMAMODEL
            self.model = LAMA_MPE
            self.model.to(self.device)
        self.inpaint_by_block = True if self.device == 'cuda' else False
        self.inpaint_size = int(self.setup_params['inpaint_size']['select'])
@@ -265,7 +248,7 @@ class LamaInpainter(InpainterBase):
        mask_original = mask_original[:, :, None]

        new_shape = self.inpaint_size if max(img.shape[0: 2]) > self.inpaint_size else None
        new_shape = 640
        # high resolution input could produce cloudy artifacts
        img = resize_keepasp(img, new_shape, stride=64)
        mask = resize_keepasp(mask, new_shape, stride=64)

@@ -280,19 +263,24 @@ class LamaInpainter(InpainterBase):
        mask_torch = torch.from_numpy(mask).unsqueeze_(0).unsqueeze_(0).float() / 255.0
        mask_torch[mask_torch < 0.5] = 0
        mask_torch[mask_torch >= 0.5] = 1
        rel_pos, _, direct = self.model.load_masked_position_encoding(mask_torch[0][0].numpy())
        rel_pos = torch.LongTensor(rel_pos).unsqueeze_(0)
        direct = torch.LongTensor(direct).unsqueeze_(0)

        if self.device == 'cuda':
            img_torch = img_torch.cuda()
            mask_torch = mask_torch.cuda()
            rel_pos = rel_pos.cuda()
            direct = direct.cuda()
        img_torch *= (1 - mask_torch)
        return img_torch, mask_torch, img_original, mask_original, pad_bottom, pad_right
        return img_torch, mask_torch, rel_pos, direct, img_original, mask_original, pad_bottom, pad_right

    @torch.no_grad()
    def _inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:

        im_h, im_w = img.shape[:2]
        img_torch, mask_torch, img_original, mask_original, pad_bottom, pad_right = self.inpaint_preprocess(img, mask)
        img_inpainted_torch = self.model(img_torch, mask_torch)
        img_torch, mask_torch, rel_pos, direct, img_original, mask_original, pad_bottom, pad_right = self.inpaint_preprocess(img, mask)
        img_inpainted_torch = self.model(img_torch, mask_torch, rel_pos, direct)
        
        img_inpainted = (img_inpainted_torch.cpu().squeeze_(0).permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        if pad_bottom > 0:
+6 −112
Original line number Diff line number Diff line
@@ -251,115 +251,9 @@ class AOTGenerator(nn.Module) :
		else :
			return torch.clip(x, -1, 1)

# class AOTInpainterTorch:
# 	def __init__(self, model_path: str, device: str = 'cpu'):
# 		self.device = device
# 		self.net = AOTGenerator(in_ch=4, out_ch=3, ch=32, alpha=0.0)
# 		sd = torch.load(model_path, map_location = 'cpu')
# 		self.net.load_state_dict(sd['model'] if 'model' in sd else sd)
# 		self.net.eval().to(device)
    
# 	def inpaint_preprocess(self, img: np.ndarray, mask: np.ndarray, inpaint_size: int = 1024) -> np.ndarray:

# 		pad_size = 4
# 		img_original = np.copy(img)
# 		mask_original = np.copy(mask)
# 		mask_original[mask_original < 127] = 0
# 		mask_original[mask_original >= 127] = 1
# 		mask_original = mask_original[:, :, None]
# 		h, w, c = img.shape
# 		new_shape = inpaint_size if max(img.shape[0: 2]) > inpaint_size else None

# 		img = resize_keepasp(img, new_shape, stride=None)
# 		mask = resize_keepasp(mask, new_shape, stride=None)

# 		im_h, im_w = img.shape[:2]
# 		pad_bottom = 128 - im_h if im_h < 128 else 0
# 		pad_right = 128 - im_w if im_w < 128 else 0
# 		mask = cv2.copyMakeBorder(mask, 0, pad_bottom, 0, pad_right, cv2.BORDER_REFLECT)
# 		img = cv2.copyMakeBorder(img, 0, pad_bottom, 0, pad_right, cv2.BORDER_REFLECT)

# 		img_torch = torch.from_numpy(img).permute(2, 0, 1).unsqueeze_(0).float() / 127.5 - 1.0
# 		mask_torch = torch.from_numpy(mask).unsqueeze_(0).unsqueeze_(0).float() / 255.0
# 		mask_torch[mask_torch < 0.5] = 0
# 		mask_torch[mask_torch >= 0.5] = 1

# 		if self.device == 'cuda':
# 			img_torch = img_torch.cuda()
# 			mask_torch = mask_torch.cuda()
# 		img_torch *= (1 - mask_torch)
# 		return img_torch, mask_torch, img_original, mask_original, pad_bottom, pad_right

# 	@torch.no_grad()
# 	def __call__(self, img: np.ndarray, mask: np.ndarray, inpaint_size: int = 1024) -> np.ndarray:
# 		im_h, im_w = img.shape[:2]
# 		img_torch, mask_torch, img_original, mask_original, pad_bottom, pad_right = self.inpaint_preprocess(img, mask, inpaint_size)
# 		img_inpainted_torch = self.net(img_torch, mask_torch)
# 		img_inpainted = ((img_inpainted_torch.cpu().squeeze_(0).permute(1, 2, 0).numpy() + 1.0) * 127.5).astype(np.uint8)

# 		if pad_bottom > 0:
# 			img_inpainted = img_inpainted[:-pad_bottom]
# 		if pad_right > 0:
# 			img_inpainted = img_inpainted[:, :-pad_right]

# 		new_shape = img_inpainted.shape[:2]
# 		if new_shape[0] != im_h or new_shape[1] != im_w :
# 			img_inpainted = cv2.resize(img_inpainted, (im_w, im_h), interpolation = cv2.INTER_LINEAR)
		
# 		img_inpainted = img_inpainted * mask_original + img_original * (1 - mask_original)

# 		return img_inpainted


# def dispatch(use_inpainting: bool, use_poisson_blending: bool, cuda: bool, img: np.ndarray, mask: np.ndarray, inpainting_size: int = 1024, model_name: str = 'default', verbose: bool = False) -> np.ndarray :
# 	img_original = np.copy(img)
# 	mask_original = np.copy(mask)
# 	mask_original[mask_original < 127] = 0
# 	mask_original[mask_original >= 127] = 1
# 	mask_original = mask_original[:, :, None]
# 	if not use_inpainting :
# 		img = np.copy(img)
# 		img[mask > 0] = np.array([255, 255, 255], np.uint8)
# 		if verbose :
# 			return img, img
# 		else :
# 			return img
# 	height, width, c = img.shape
# 	if max(img.shape[0: 2]) > inpainting_size :
# 		img = resize_keep_aspect(img, inpainting_size)
# 		mask = resize_keep_aspect(mask, inpainting_size)
# 	pad_size = 4
# 	h, w, c = img.shape
# 	if h % pad_size != 0 :
# 		new_h = (pad_size - (h % pad_size)) + h
# 	else :
# 		new_h = h
# 	if w % pad_size != 0 :
# 		new_w = (pad_size - (w % pad_size)) + w
# 	else :
# 		new_w = w
# 	if new_h != h or new_w != w :
# 		img = cv2.resize(img, (new_w, new_h), interpolation = cv2.INTER_LINEAR)
# 		mask = cv2.resize(mask, (new_w, new_h), interpolation = cv2.INTER_LINEAR)
# 	if verbose :
# 		print(f'Inpainting resolution: {new_w}x{new_h}')
# 	img_torch = torch.from_numpy(img).permute(2, 0, 1).unsqueeze_(0).float() / 127.5 - 1.0
# 	mask_torch = torch.from_numpy(mask).unsqueeze_(0).unsqueeze_(0).float() / 255.0
# 	mask_torch[mask_torch < 0.5] = 0
# 	mask_torch[mask_torch >= 0.5] = 1
# 	if cuda :
# 		img_torch = img_torch.cuda()
# 		mask_torch = mask_torch.cuda()
# 	with torch.no_grad() :
# 		img_torch *= (1 - mask_torch)
# 		img_inpainted_torch = DEFAULT_MODEL(img_torch, mask_torch)
# 	img_inpainted = ((img_inpainted_torch.cpu().squeeze_(0).permute(1, 2, 0).numpy() + 1.0) * 127.5).astype(np.uint8)
# 	if new_h != height or new_w != width :
# 		img_inpainted = cv2.resize(img_inpainted, (width, height), interpolation = cv2.INTER_LINEAR)
# 	if use_poisson_blending :
# 		raise NotImplemented
# 	else :
# 		ans = img_inpainted * mask_original + img_original * (1 - mask_original)
# 	if verbose :
# 		return ans, (img_torch.cpu() * 127.5 + 127.5).squeeze_(0).permute(1, 2, 0).numpy()
# 	return ans
 No newline at end of file
def load_aot_model(model_path, device) -> AOTGenerator:
    """Create an ``AOTGenerator`` and restore its weights from a checkpoint.

    The checkpoint at ``model_path`` is deserialized onto the CPU. When it
    contains a ``'model'`` entry, that entry is used as the state dict;
    otherwise the loaded object itself is used. The network is switched to
    evaluation mode and then moved to ``device`` before being returned.
    """
    net = AOTGenerator(in_ch=4, out_ch=3, ch=32, alpha=0.0)
    checkpoint = torch.load(model_path, map_location = 'cpu')
    if 'model' in checkpoint:
        weights = checkpoint['model']
    else:
        weights = checkpoint
    net.load_state_dict(weights)
    net.eval().to(device)
    return net
 No newline at end of file
+173 −8
Original line number Diff line number Diff line
@@ -4,6 +4,8 @@ import torch
import torch.nn as nn
from torch import Tensor
import numpy as np
import cv2

from .ffc import FFC_BN_ACT

def get_activation(kind='tanh'):
@@ -117,9 +119,17 @@ class FFCResNetGenerator(nn.Module):
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def forward(self, img, mask) -> Tensor:
    def forward(self, img, mask, rel_pos=None, direct=None) -> Tensor:
        masked_img = torch.cat([img * (1 - mask), mask], dim=1)
        if rel_pos is None:
            return self.model(masked_img)
        else:
            
            x_l, x_g = self.model[:2](masked_img)
            x_l = x_l.to(torch.float32)
            x_l += rel_pos
            x_l += direct
            return self.model[2:]((x_l, x_g))


class NLayerDiscriminator(nn.Module):
@@ -176,8 +186,72 @@ def set_requires_grad(module, value):
    for param in module.parameters():
        param.requires_grad = value


class MaskedSinusoidalPositionalEmbedding(nn.Embedding):
    """Embedding table whose rows are fixed sinusoidal position codes.

    The table is filled once at construction time and marked as not
    requiring gradients; lookups run under ``torch.no_grad()``.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__(num_embeddings, embedding_dim)
        # Overwrite the randomly initialised table with the sinusoidal one.
        self.weight = self._init_weight(self.weight)

    @staticmethod
    def _init_weight(out: nn.Parameter):
        """Fill ``out`` in place with sinusoidal features and return it.

        Features are not interleaved: the sine features occupy the first
        half of every row and the cosine features the second half (for an
        odd width the sine half gets the extra column).
        """
        n_pos, dim = out.shape
        # One divisor per feature column; columns 2k and 2k+1 share a value.
        divisors = [np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
        position_enc = np.array(
            [[pos / divisor for divisor in divisors] for pos in range(n_pos)]
        )
        # Must be disabled before the in-place writes below (pytorch-1.8+).
        out.requires_grad = False
        sentinel = (dim + 1) // 2
        sin_part = np.sin(position_enc[:, 0::2])
        cos_part = np.cos(position_enc[:, 1::2])
        out[:, :sentinel] = torch.FloatTensor(sin_part)
        out[:, sentinel:] = torch.FloatTensor(cos_part)
        out.detach_()
        return out

    @torch.no_grad()
    def forward(self, input_ids):
        """Look up embeddings; `input_ids` is expected to be [bsz x seqlen]."""
        return super().forward(input_ids)


class MultiLabelEmbedding(nn.Module):
    """Learned embedding applied to multi-hot label vectors via a matmul.

    Holds a ``(num_positions, embedding_dim)`` weight matrix drawn from a
    standard normal distribution.
    """

    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.Tensor(num_positions, embedding_dim))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the weight matrix from N(0, 1)."""
        nn.init.normal_(self.weight, mean=0.0, std=1.0)

    def forward(self, input_ids):
        """Project label vectors onto the embedding matrix.

        `input_ids` is [B, HW, 4] (one-hot / multi-hot); the result is
        [B, HW, dim].
        """
        return input_ids @ self.weight


class MPE(nn.Module):
    """Masked position encoding: embeds per-pixel distance and direction maps.

    Combines a fixed sinusoidal table (128 positions, 64 channels) with a
    learned 4-label embedding (64 channels). Each output is scaled by its own
    learnable scalar, both of which start at zero.
    """

    def __init__(self):
        super().__init__()
        self.rel_pos_emb = MaskedSinusoidalPositionalEmbedding(
            num_embeddings=128,
            embedding_dim=64,
        )
        self.direct_emb = MultiLabelEmbedding(num_positions=4, embedding_dim=64)
        self.alpha5 = nn.Parameter(torch.tensor(0, dtype=torch.float32), requires_grad=True)
        self.alpha6 = nn.Parameter(torch.tensor(0, dtype=torch.float32), requires_grad=True)

    def forward(self, rel_pos=None, direct=None):
        """Return ``(rel_pos_emb, direct_emb)``, each laid out as [B, C, H, W].

        ``rel_pos`` is unpacked as a [B, H, W] index map and ``direct`` is
        reshaped to [B, H*W, 4] before being embedded.
        """
        b, h, w = rel_pos.shape
        n_pixels = h * w

        # Distance indices -> sinusoidal features, moved to channels-first.
        pos_feat = self.rel_pos_emb(rel_pos.reshape(b, n_pixels))
        pos_feat = pos_feat.reshape(b, h, w, -1).permute(0, 3, 1, 2)
        rel_pos_emb = pos_feat * self.alpha5

        # Direction flags -> learned features, moved to channels-first.
        direct_flat = direct.reshape(b, n_pixels, 4).to(torch.float32)
        dir_feat = self.direct_emb(direct_flat)
        dir_feat = dir_feat.reshape(b, h, w, -1).permute(0, 3, 1, 2)
        direct_emb = dir_feat * self.alpha6

        return rel_pos_emb, direct_emb

class LamaFourier:
    def __init__(self, build_discriminator=True) -> None:
    def __init__(self, build_discriminator=True, use_mpe=False) -> None:
        # super().__init__()
        self.generator = FFCResNetGenerator(4, 3, add_out_act='sigmoid', 
                            init_conv_kwargs={
@@ -194,8 +268,13 @@ class LamaFourier:
                            'enable_lfu': False
                        }
                    )
        self.enable_fp16 = False
        self.discriminator = NLayerDiscriminator() if build_discriminator else None
        self.inpaint_only = False
        if use_mpe:
            self.mpe = MPE()
        else:
            self.mpe = None

    def train_generator(self):
        self.inpaint_only = False
@@ -205,6 +284,8 @@ class LamaFourier:
        self.discriminator.eval()
        set_requires_grad(self.discriminator, False)
        set_requires_grad(self.generator, True)
        if self.mpe is not None:
            set_requires_grad(self.mpe, True)

    def train_discriminator(self):
        self.inpaint_only = False
@@ -214,18 +295,32 @@ class LamaFourier:
        self.generator.eval()
        set_requires_grad(self.discriminator, True)
        set_requires_grad(self.generator, False)
        if self.mpe is not None:
            set_requires_grad(self.mpe, False)

    def to(self, device):
        """Move the generator and any optional sub-modules to ``device``.

        The discriminator and the MPE module are moved only when present.
        Nothing is returned.
        """
        self.generator.to(device)
        for optional_module in (self.discriminator, self.mpe):
            if optional_module is not None:
                optional_module.to(device)

    def eval(self):
        self.inpaint_only = True
        return self.generator.eval()
        self.generator.eval()
        if self.mpe is not None:
            self.mpe.eval()
        return self
        

    def __call__(self, img: Tensor, mask: Tensor, rel_pos=None, direct=None):

        if self.mpe is not None:
            rel_pos, direct = self.mpe(rel_pos, direct)
        else:
            rel_pos, direct = None, None
        predicted_img = self.generator(img, mask, rel_pos, direct)

    def __call__(self, img: Tensor, mask: Tensor):
        predicted_img = self.generator(img, mask)
        if self.inpaint_only:
            return predicted_img * mask + (1 - mask) * img

@@ -233,6 +328,7 @@ class LamaFourier:
            predicted_img = predicted_img.detach()
            img.requires_grad = True


        discr_real_pred, discr_real_features = self.discriminator(img)
        discr_fake_pred, discr_fake_features = self.discriminator(predicted_img)
        # fp = discr_fake_pred.detach().mean()
@@ -251,4 +347,73 @@ class LamaFourier:
                'discr_fake_pred': discr_fake_pred
            }

        
 No newline at end of file
    def load_masked_position_encoding(self, mask):
        """Compute distance and direction maps for the masked region.

        The mask is multiplied by 255 and cast to ``uint8`` (so it presumably
        arrives as a 2-D array with values in [0, 1], 1 = masked -- confirm
        against the caller), resized to 256x256, and any non-zero pixel is
        treated as masked. The unmasked area is then grown one 3x3 dilation
        step at a time; every masked pixel records the step at which it was
        first reached, plus four flags telling which of four 2x2 directional
        kernels reached it at some step.

        Returns:
            rel_pos: int32 map of the step index clipped to ``[0, 127]``,
                resized back to the input resolution (nearest neighbour) when
                it differs from 256x256, with pixels outside the original
                mask zeroed.
            abs_pos: int32 map of the unclipped step index. It always stays
                at the 256x256 working resolution.
            direct: int32 ``(H, W, 4)`` direction flags, resized and zeroed
                outside the mask like ``rel_pos``.
        """
        mask = (mask * 255).astype(np.uint8)
        ones_filter = np.ones((3, 3), dtype=np.float32)
        # One kernel per direction channel; the order fixes the channel index.
        direction_filters = (
            np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32),
            np.array([[0, 0, 0], [1, 1, 0], [1, 1, 0]], dtype=np.float32),
            np.array([[0, 1, 1], [0, 1, 1], [0, 0, 0]], dtype=np.float32),
            np.array([[0, 0, 0], [0, 1, 1], [0, 1, 1]], dtype=np.float32),
        )
        str_size = 256
        pos_num = 128

        ori_h, ori_w = mask.shape[0:2]
        ori_mask = mask / 255
        mask = cv2.resize(mask, (str_size, str_size), interpolation=cv2.INTER_AREA)
        mask[mask > 0] = 255
        h, w = mask.shape[0:2]
        # mask3 is 1 on known pixels and 0 on pixels still to be reached.
        mask3 = 1. - (mask / 255.0)
        pos = np.zeros((h, w), dtype=np.int32)
        direct = np.zeros((h, w, 4), dtype=np.int32)
        i = 0
        while np.sum(1 - mask3) > 0:
            i += 1
            mask3_ = cv2.filter2D(mask3, -1, ones_filter)
            mask3_[mask3_ > 0] = 1
            newly_reached = (mask3_ - mask3) == 1
            if not newly_reached.any():
                # No known pixel to grow from (the mask covers the whole
                # image): the region can never shrink, so stop instead of
                # looping forever. All outputs stay zero in that case.
                break
            pos[newly_reached] = i

            for channel, d_filter in enumerate(direction_filters):
                m = cv2.filter2D(mask3, -1, d_filter)
                m[m > 0] = 1
                m = m - mask3
                direct[m == 1, channel] = 1

            mask3 = mask3_

        abs_pos = pos.copy()
        rel_pos = pos / (str_size / 2)  # to 0~1 maybe larger than 1
        rel_pos = (rel_pos * pos_num).astype(np.int32)
        rel_pos = np.clip(rel_pos, 0, pos_num - 1)

        if ori_w != w or ori_h != h:
            rel_pos = cv2.resize(rel_pos, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST)
            rel_pos[ori_mask == 0] = 0
            direct = cv2.resize(direct, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST)
            direct[ori_mask == 0, :] = 0

        return rel_pos, abs_pos, direct

def load_lama_mpe(model_path, device) -> LamaFourier:
    """Create a ``LamaFourier`` with MPE enabled and restore it from disk.

    The checkpoint at ``model_path`` is deserialized onto the CPU. Its
    ``'gen_state_dict'`` entry is loaded into the generator and its
    ``'str_state_dict'`` entry into the MPE module; both keys are required.
    The model is then put in evaluation mode and moved to ``device``.
    """
    model = LamaFourier(build_discriminator=False, use_mpe=True)
    checkpoint = torch.load(model_path, map_location = 'cpu')
    targets = (
        (model.generator, 'gen_state_dict'),
        (model.mpe, 'str_state_dict'),
    )
    for module, key in targets:
        module.load_state_dict(checkpoint[key])
    model.eval().to(device)
    return model
 No newline at end of file
+6 −2
Original line number Diff line number Diff line
import time
from typing import Union
import numpy as np
import traceback


from PyQt5.QtCore import QThread, pyqtSignal, QObject, QLocale
from PyQt5.QtWidgets import QMessageBox
@@ -48,7 +50,8 @@ class ModuleThread(QThread):
        except Exception as e:
            self.module = old_module
            msg = self.tr('Failed to set ') + module_name
            self.exception_occurred.emit(msg, str(e))
            
            self.exception_occurred.emit(msg, str(e) + '\n' + f'exc: {traceback.format_exc()}')
        self.finish_set_module.emit()

    def pipeline_finished(self):
@@ -103,7 +106,8 @@ class InpaintThread(ModuleThread):
            }
            self.finish_inpaint.emit(inpaint_dict)
        except Exception as e:
            self.exception_occurred.emit(self.tr('Inpainting Failed.'), repr(e))
            # self.exception_occurred.emit(self.tr('Inpainting Failed.'), repr(e))
            self.exception_occurred.emit(self.tr('Inpainting Failed.'), str(e) + '\n' + f'exc: {traceback.format_exc()}')


class TextDetectThread(ModuleThread):