Commit 4231f1c9 authored by dmMaze's avatar dmMaze
Browse files

add YSGYoloDetector

parent cb8ed0b9
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -137,6 +137,7 @@ Sugoi 翻译器作者: [mingshiba](https://www.patreon.com/mingshiba)
 * 暂时仅支持日文(方块字都差不多)和英文检测,训练代码和说明见https://github.com/dmMaze/comic-text-detector
 * 支持使用 [星河云(团子漫画OCR)](https://cloud.stariver.org.cn/)的文本检测,需要填写用户名和密码,每次启动时会自动登录。
   * 详细说明见 [团子OCR说明](doc/团子OCR说明.md)
 * `YSGDetector` 是由 [lhj5426](https://github.com/lhj5426) 训练的模型,能更好地过滤日漫/CG里的拟声词。需要手动从 [YSGYoloDetector](https://huggingface.co/dreMaz/YSGYoloDetector) 下载模型放到 data/models 目录下。


### OCR
+5 −1
Original line number Diff line number Diff line
@@ -139,6 +139,10 @@ This project is heavily dependent upon [manga-image-translator](https://github.c
 * Support using text detection from [Starriver Cloud (Tuanzi Manga OCR)](https://cloud.stariver.org.cn/). Username and password need to be filled in, and automatic login will be performed each time the program is launched.

   * For detailed instructions, see **Tuanzi OCR Instructions**: ([Chinese](doc/团子OCR说明.md) & [Brazilian Portuguese](doc/Manual_TuanziOCR_pt-BR.md) only)
 
 * `YSGDetector` models are trained by [lhj5426](https://github.com/lhj5426), these models would filter out onomatopoeia in CGs/Manga, download checkpoints from [YSGYoloDetector](https://huggingface.co/dreMaz/YSGYoloDetector) and put into `data/models`. 


## OCR
 * All mit* models are from manga-image-translator, support English, Japanese and Korean recognition and text color extraction.
 * [manga_ocr](https://github.com/kha-white/manga-ocr) is from [kha-white](https://github.com/kha-white), text recognition for Japanese, with the main focus being Japanese manga.
+5 −0
Original line number Diff line number Diff line
@@ -126,6 +126,8 @@ class BaseModule:
                if hasattr(self, k):
                    model = getattr(self, k)
                    if model is not None:
                        if hasattr(model, 'unload_model'):
                            model.unload_model(empty_cache=False)
                        del model
                        setattr(self, k, None)
                        model_deleted = True
@@ -158,6 +160,9 @@ class BaseModule:
    def debug_mode(self):
        return shared.DEBUG
    
    def flush(self, param_key: str):
        return None

os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import torch

+8 −3
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@ from typing import Union, List, Tuple
from collections import OrderedDict

from utils.textblock import TextBlock
from utils.proj_imgtrans import ProjImgTrans

from utils.registry import Registry
TEXTDETECTORS = Registry('textdetectors')
@@ -26,16 +27,20 @@ class TextDetectorBase(BaseModule):
                self.name = key
                break

    def _detect(self, *args, **kwargs) -> Tuple[np.ndarray, List[TextBlock]]:
    def _detect(self, img: np.ndarray, proj: ProjImgTrans) -> Tuple[np.ndarray, List[TextBlock]]:
        '''
        The proj context can be accessed via ```proj```
        '''
        raise NotImplementedError

    def setup_detector(self):
        raise NotImplementedError

    def detect(self, img: np.ndarray) -> Tuple[np.ndarray, List[TextBlock]]:
    def detect(self, img: np.ndarray, proj: ProjImgTrans = None) -> Tuple[np.ndarray, List[TextBlock]]:
        # TODO: allow processing proj entirely in _detect and yield progress
        if not self.all_model_loaded():
            self.load_model()
        mask, blk_list = self._detect(img)
        mask, blk_list = self._detect(img, proj)
        for blk in blk_list:
            blk.det_model = self.name
        return mask, blk_list
+2 −2
Original line number Diff line number Diff line
import numpy as np
from typing import Tuple, List

from .base import register_textdetectors, TextDetectorBase, TextBlock, DEFAULT_DEVICE, DEVICE_SELECTOR
from .base import register_textdetectors, TextDetectorBase, TextBlock, DEFAULT_DEVICE, DEVICE_SELECTOR, ProjImgTrans
from .ctd import CTDModel

CTD_ONNX_PATH = 'data/models/comictextdetector.pt.onnx'
@@ -57,7 +57,7 @@ class ComicTextDetector(TextDetectorBase):
        else:
            self.model = load_ctd_model(CTD_ONNX_PATH, self.device, self.detect_size)

    def _detect(self, img: np.ndarray) -> Tuple[np.ndarray, List[TextBlock]]:
    def _detect(self, img: np.ndarray, proj: ProjImgTrans) -> Tuple[np.ndarray, List[TextBlock]]:
        _, mask, blk_list = self.model(img)
        return mask, blk_list

Loading