Commit 114afba0 authored by dmMaze's avatar dmMaze
Browse files

support ysg RTDERT, close #870

parent 6260abf0
Loading
Loading
Loading
Loading
+12 −7
Original line number Diff line number Diff line
@@ -75,16 +75,21 @@ class YSGYoloDetector(TextDetectorBase):
        'mask dilate size': 2
    }

    _load_model_keys = {'yolo'}
    _load_model_keys = {'model'}

    def __init__(self, **params) -> None:
        super().__init__(**params)
        update_ckpt_list()
    
    def _load_model(self):
        from ultralytics import YOLO
        if not hasattr(self, 'yolo') or self.yolo is None:
            self.yolo = YOLO(self.get_param_value('model path')).to(device=self.get_param_value('device'))
        model_path = self.get_param_value('model path')

        if 'rtdetr' in os.path.basename(model_path):
            from ultralytics import RTDETR as MODEL
        else:
            from ultralytics import YOLO as MODEL
        if not hasattr(self, 'model') or self.model is None:
            self.model = MODEL(model_path).to(device=self.get_param_value('device'))

    def get_valid_labels(self):
        valid_labels = [k for k, v in self.params['label']['value'].items() if v and k != 'textblock']
@@ -96,7 +101,7 @@ class YSGYoloDetector(TextDetectorBase):

    def _detect(self, img: np.ndarray, proj: ProjImgTrans = None) -> Tuple[np.ndarray, List[TextBlock]]:

        result = self.yolo.predict(
        result = self.model.predict(
            source=img, save=False, show=False, verbose=False, 
            conf=self.get_param_value('confidence threshold'), iou=self.get_param_value('IoU threshold'),
            agnostic_nms=True
@@ -233,8 +238,8 @@ class YSGYoloDetector(TextDetectorBase):
        super().updateParam(param_key, param_content)
        
        if param_key == 'model path':
            if hasattr(self, 'yolo'):
                del self.yolo
            if hasattr(self, 'model'):
                del self.model

    def flush(self, param_key: str):
        if param_key == 'model path':