Commit 25d3ed20 authored by narugo1992
Browse files

dev(narugo): update train function

parent d47c8bff
Loading
Loading
Loading
Loading
+4 −10
Original line number Diff line number Diff line
@@ -16,7 +16,6 @@ from .dataset import MonochromeDataset
from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from .transformer import SigTransformer
from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR
from ..utils import LRTyping, get_init_lr

# Training directory for this module: the 'monochrome' subdirectory of the
# global training directory imported from ..base.
_TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'monochrome')
# Root directory for logs; train() joins the session name onto this path to
# get a per-session log directory.
_LOG_DIR = os.path.join(_TRAIN_DIR, 'logs')
@@ -71,9 +70,9 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]:

def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None,
          train_ratio: float = 0.8, batch_size: int = 4, feature_bins: int = 180, fc: Optional[int] = 75,
          max_epochs: int = 500, learning_rate: LRTyping = 0.001,
          weight_decay: float = 1e-3, num_workers: Optional[int] = None,
          device: Optional[str] = None, save_per_epoch: int = 10, model_name: str = 'alexnet'):
          max_epochs: int = 500, learning_rate: float = 0.001, weight_decay: float = 1e-3,
          num_workers: Optional[int] = None, device: Optional[str] = None,
          save_per_epoch: int = 10, model_name: str = 'alexnet'):
    session_name = session_name or model_name
    _log_dir = os.path.join(_LOG_DIR, session_name)
    os.makedirs(_log_dir, exist_ok=True)
@@ -100,7 +99,6 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio

    # Load previous epoch
    model = _KNOWN_MODELS[model_name]().float()
    # model = MonochromeAlexNet().float()
    if from_ckpt is None:
        from_ckpt = _find_latest_ckpt(session_name)
    previous_epoch = _ckpt_epoch(from_ckpt) or 0
@@ -116,11 +114,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    model = model.to(device)

    loss_fn = nn.CrossEntropyLoss()
    initial_lr = get_init_lr(learning_rate)
    optimizer = torch.optim.AdamW(
        [{'params': model.parameters(), 'initial_lr': initial_lr}],
        lr=initial_lr, weight_decay=weight_decay,
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = lr_scheduler.OneCycleLR(
        optimizer, max_lr=learning_rate,
        steps_per_epoch=len(train_dataloader), epochs=max_epochs,