Commit 82764754 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo); update script

parent 560f4a01
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -7,7 +7,8 @@ from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR
_TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'face_detect')


def train(train_cfg: str, session_name: str, level: str = 's', max_epochs: int = 300, **kwargs):
def train(train_cfg: str, session_name: str, level: str = 's',
          max_epochs: int = 300, save_per_epoch: int = 10, **kwargs):
    # Load a pretrained YOLO model (recommended for training)
    _last_pt = os.path.join(_TRAIN_DIR, session_name, 'weights', 'last.pt')
    if os.path.exists(_last_pt):
@@ -19,6 +20,6 @@ def train(train_cfg: str, session_name: str, level: str = 's', max_epochs: int =
    model.train(
        data=train_cfg, epochs=max_epochs,
        name=session_name, project=_TRAIN_DIR,
        exist_ok=True, save=True, resume=resume,
        exist_ok=True, save=True, save_period=save_per_epoch, resume=resume,
        **kwargs
    )
+3 −2
Original line number Diff line number Diff line
@@ -7,7 +7,8 @@ from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR
_TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'person_detect')


def train(train_cfg: str, session_name: str, level: str = 's', max_epochs: int = 300, **kwargs):
def train(train_cfg: str, session_name: str, level: str = 's',
          max_epochs: int = 300, save_per_epoch: int = 10, **kwargs):
    # Load a pretrained YOLO model (recommended for training)
    _last_pt = os.path.join(_TRAIN_DIR, session_name, 'weights', 'last.pt')
    if os.path.exists(_last_pt):
@@ -19,6 +20,6 @@ def train(train_cfg: str, session_name: str, level: str = 's', max_epochs: int =
    model.train(
        data=train_cfg, epochs=max_epochs,
        name=session_name, project=_TRAIN_DIR,
        exist_ok=True, save=True, resume=resume,
        exist_ok=True, save=True, save_period=save_per_epoch, resume=resume,
        **kwargs
    )