Commit de58de0d authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add manbits training script

parent 629b1ee8
Loading
Loading
Loading
Loading
+26 −0
Original line number Diff line number Diff line
import os.path

from ultralytics import YOLO

from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR

_TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'manbits_detect')


def train(train_cfg: str, session_name: str, level: str = 'm',
          max_epochs: int = 300, save_per_epoch: int = 10, **kwargs):
    # Load a pretrained YOLO model (recommended for training)
    _last_pt = os.path.join(_TRAIN_DIR, session_name, 'weights', 'last.pt')
    if os.path.exists(_last_pt):
        model, resume = YOLO(_last_pt), True
    else:
        model, resume = YOLO(f'yolov8{level}.pt'), False

    # Train the model using the 'coco128.yaml' dataset for 3 epochs
    model.train(
        data=train_cfg, epochs=max_epochs,
        name=session_name, project=_TRAIN_DIR,
        save=True, save_period=save_per_epoch, plots=True,
        exist_ok=True, resume=resume,
        **kwargs
    )
+4 −4
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ def export_yolo_to_onnx(yolo: YOLO, onnx_filename, opset_version: int = 14,
                        no_optimize: bool = False):
    if os.path.dirname(onnx_filename):
        os.makedirs(os.path.dirname(onnx_filename), exist_ok=True)
    copy(
        yolo.export(format='onnx', dynamic=True, simplify=not no_optimize, opset=opset_version),
        onnx_filename
    )

    _retval = yolo.export(format='onnx', dynamic=True, simplify=not no_optimize, opset=opset_version)
    _exported_onnx_file = _retval or (os.path.splitext(yolo.ckpt_path)[0] + '.onnx')
    copy(_exported_onnx_file, onnx_filename)