Commit 195cc8f4 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): save the fixes

parent c7469ff6
Loading
Loading
Loading
Loading
+8 −5
Original line number Diff line number Diff line
@@ -288,7 +288,10 @@ class YOLOModel:
                token=self._get_hf_token(),
            ))
            model_metadata = model.get_modelmeta()
            if 'imgsz' in model_metadata.custom_metadata_map:
                max_infer_size = max(json.loads(model_metadata.custom_metadata_map['imgsz']))
            else:
                max_infer_size = 640
            names_map = _safe_eval_names_str(model_metadata.custom_metadata_map['names'])
            labels = ['<unknown>'] * (max(names_map.keys()) + 1)
            for id_, name in names_map.items():
@@ -299,7 +302,7 @@ class YOLOModel:

    def predict(self, image: ImageTyping, model_name: str,
                conf_threshold: float = 0.25, iou_threshold: float = 0.7) \
            -> Tuple[Tuple[int, int, int, int], str, float]:
            -> List[Tuple[Tuple[int, int, int, int], str, float]]:
        model, max_infer_size, labels = self._open_model(model_name)
        image = load_image(image, mode='RGB')
        new_image, old_size, new_size = _image_preprocess(image, max_infer_size)
@@ -316,9 +319,9 @@ def _open_models_for_repo_id(repo_id: str) -> YOLOModel:
    return YOLOModel(repo_id)


def classify_predict_score(image: ImageTyping, repo_id: str, model_name: str,
def yolo_predict(image: ImageTyping, repo_id: str, model_name: str,
                 conf_threshold: float = 0.25, iou_threshold: float = 0.7) \
        -> Tuple[Tuple[int, int, int, int], str, float]:
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    return _open_models_for_repo_id(repo_id).predict(
        image=image,
        model_name=model_name,