Commit 904cdff6 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add dynamic learning rate

parent 2f1c86ca
Loading
Loading
Loading
Loading
+7 −3
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@ from .dataset import MonochromeDataset
from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from .transformer import SigTransformer
from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR
from ..utils import LRTyping, get_init_lr, get_dynamic_lr_scheduler

_TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'monochrome')
_LOG_DIR = os.path.join(_TRAIN_DIR, 'logs')
@@ -69,7 +70,7 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]:

def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None,
          train_ratio: float = 0.8, batch_size: int = 4, feature_bins: int = 256, fc: Optional[int] = 100,
          max_epochs: int = 500, learning_rate: float = 0.001,
          max_epochs: int = 500, learning_rate: LRTyping = 0.001,
          save_per_epoch: int = 10, model_name: str = 'alexnet'):
    session_name = session_name or model_name
    _log_dir = os.path.join(_LOG_DIR, session_name)
@@ -110,7 +111,9 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        model = model.cuda()

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    initial_lr = get_init_lr(learning_rate)
    optimizer = torch.optim.Adam([{'params': model.parameters(), 'initial_lr': initial_lr}], lr=initial_lr)
    scheduler = get_dynamic_lr_scheduler(optimizer, lr=learning_rate, last_epoch=previous_epoch)

    for epoch in tqdm(range(previous_epoch + 1, max_epochs + 1)):
        running_loss = 0.0
@@ -128,7 +131,8 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
            running_loss += loss.item() * inputs.size(0)

        epoch_loss = running_loss / len(train_dataset)
        logging.info(f'Epoch {epoch} loss: {epoch_loss:.4f}')
        logging.info(f'Epoch {epoch} loss: {epoch_loss:.4f}, with learning rate: {scheduler.get_last_lr()[0]:.6f}')
        scheduler.step()
        writer.add_scalar('train/loss', epoch_loss, epoch)

        with torch.no_grad():
+1 −0
Original line number Diff line number Diff line
from .cli import GLOBAL_CONTEXT_SETTINGS, print_version
from .lr import get_init_lr, get_dynamic_lr_scheduler, LRTyping
from .optimize import onnx_optimize
from .testfile import get_testfile

zoo/utils/lr.py

0 → 100644
+47 −0
Original line number Diff line number Diff line
from typing import Tuple, List, Union, Optional

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR

LRTyping = Union[float, List[Union[Tuple[int, float], float]]]


def _process_lr(lr: LRTyping) -> List[Tuple[Optional[int], float]]:
    """
    Normalize a learning-rate specification into an ordered list of stages.

    The specification is either a single number, or a sequence whose items
    are ``(epoch, lr)`` pairs or bare numbers. Pairs may be tuples or
    2-element lists (the latter is what JSON/YAML configs produce).

    Ordering of the result:

    * pairs come first, sorted by ascending epoch; for equal epochs the
      larger learning rate comes first, then the original position;
    * bare numbers follow, larger learning rate first, then the original
      position.

    :param lr: Learning-rate specification.
    :return: List of ``(epoch, lr)`` tuples. ``epoch`` is ``None`` for bare
        numbers and for pairs whose epoch is negative.
    """
    # A bare number is a single-stage schedule. ``int`` is accepted as well
    # as ``float``: previously an integer fell through to ``enumerate`` and
    # raised ``TypeError``.
    if isinstance(lr, (int, float)):
        lr = [lr]

    keyed = []
    for index, item in enumerate(lr):
        if isinstance(item, (tuple, list)):
            # Explicit (epoch, lr) pair; group 0 sorts ahead of bare values.
            keyed.append((0, item[0], -item[1], index, item[1]))
        else:
            # Bare value: epoch -1 is the internal marker turned into None below.
            keyed.append((1, -1, -item, index, item))

    # The first four fields form the sort key; the trailing field carries the
    # untouched learning rate so it does not have to be un-negated.
    keyed.sort(key=lambda entry: entry[:4])
    return [(epoch if epoch >= 0 else None, value) for _, epoch, _, _, value in keyed]


def get_init_lr(lr: LRTyping) -> float:
    """
    Return the learning rate of the first stage of a specification.

    The specification is normalized with ``_process_lr`` and the learning
    rate of the first entry of that list is returned.

    :param lr: Learning-rate specification.
    :return: Learning rate of the first normalized stage.
    :raises ValueError: If ``lr`` is falsy (for example ``0.0``, ``None``
        or an empty list).
    """
    if lr:
        stages = _process_lr(lr)
        return stages[0][1]

    raise ValueError(f'Unrecognizable lr - {lr!r}.')


def get_dynamic_lr_scheduler(optimizer: Optimizer, lr: LRTyping, **kwargs) -> LambdaLR:
    """
    Build a ``LambdaLR`` scheduler from a staged learning-rate specification.

    The specification is normalized with ``_process_lr``. For a given epoch
    the scheduler's lambda picks the first stage whose epoch bound is
    ``None`` or not smaller than the epoch, falling back to the last stage,
    and returns that stage's learning rate divided by the first stage's
    learning rate.

    :param optimizer: Optimizer handed to ``LambdaLR``.
    :param lr: Learning-rate specification.
    :param kwargs: Extra keyword arguments forwarded to ``LambdaLR``.
    :return: The configured ``LambdaLR`` instance.
    :raises ValueError: If ``lr`` is falsy.
    """
    if not lr:
        raise ValueError(f'Unrecognizable lr - {lr!r}.')

    stages = _process_lr(lr)
    base_lr = stages[0][1]
    fallback_lr = stages[-1][1]

    def _factor(epoch: int):
        # First stage that covers this epoch; the last stage when none does.
        matched = next(
            (value for bound, value in stages if bound is None or epoch <= bound),
            fallback_lr,
        )
        return matched / base_lr

    return LambdaLR(optimizer, lr_lambda=_factor, **kwargs)