Loading zoo/monochrome/train_.py +4 −10 Original line number Diff line number Diff line Loading @@ -16,7 +16,6 @@ from .dataset import MonochromeDataset from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 from .transformer import SigTransformer from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR from ..utils import LRTyping, get_init_lr _TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'monochrome') _LOG_DIR = os.path.join(_TRAIN_DIR, 'logs') Loading Loading @@ -71,9 +70,9 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]: def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None, train_ratio: float = 0.8, batch_size: int = 4, feature_bins: int = 180, fc: Optional[int] = 75, max_epochs: int = 500, learning_rate: LRTyping = 0.001, weight_decay: float = 1e-3, num_workers: Optional[int] = None, device: Optional[str] = None, save_per_epoch: int = 10, model_name: str = 'alexnet'): max_epochs: int = 500, learning_rate: float = 0.001, weight_decay: float = 1e-3, num_workers: Optional[int] = None, device: Optional[str] = None, save_per_epoch: int = 10, model_name: str = 'alexnet'): session_name = session_name or model_name _log_dir = os.path.join(_LOG_DIR, session_name) os.makedirs(_log_dir, exist_ok=True) Loading @@ -100,7 +99,6 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio # Load previous epoch model = _KNOWN_MODELS[model_name]().float() # model = MonochromeAlexNet().float() if from_ckpt is None: from_ckpt = _find_latest_ckpt(session_name) previous_epoch = _ckpt_epoch(from_ckpt) or 0 Loading @@ -116,11 +114,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio model = model.to(device) loss_fn = nn.CrossEntropyLoss() initial_lr = get_init_lr(learning_rate) optimizer = torch.optim.AdamW( [{'params': model.parameters(), 'initial_lr': initial_lr}], lr=initial_lr, weight_decay=weight_decay, ) optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) scheduler = lr_scheduler.OneCycleLR( optimizer, max_lr=learning_rate, steps_per_epoch=len(train_dataloader), epochs=max_epochs, Loading Loading
zoo/monochrome/train_.py +4 −10 Original line number Diff line number Diff line Loading @@ -16,7 +16,6 @@ from .dataset import MonochromeDataset from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 from .transformer import SigTransformer from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR from ..utils import LRTyping, get_init_lr _TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'monochrome') _LOG_DIR = os.path.join(_TRAIN_DIR, 'logs') Loading Loading @@ -71,9 +70,9 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]: def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None, train_ratio: float = 0.8, batch_size: int = 4, feature_bins: int = 180, fc: Optional[int] = 75, max_epochs: int = 500, learning_rate: LRTyping = 0.001, weight_decay: float = 1e-3, num_workers: Optional[int] = None, device: Optional[str] = None, save_per_epoch: int = 10, model_name: str = 'alexnet'): max_epochs: int = 500, learning_rate: float = 0.001, weight_decay: float = 1e-3, num_workers: Optional[int] = None, device: Optional[str] = None, save_per_epoch: int = 10, model_name: str = 'alexnet'): session_name = session_name or model_name _log_dir = os.path.join(_LOG_DIR, session_name) os.makedirs(_log_dir, exist_ok=True) Loading @@ -100,7 +99,6 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio # Load previous epoch model = _KNOWN_MODELS[model_name]().float() # model = MonochromeAlexNet().float() if from_ckpt is None: from_ckpt = _find_latest_ckpt(session_name) previous_epoch = _ckpt_epoch(from_ckpt) or 0 Loading @@ -116,11 +114,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio model = model.to(device) loss_fn = nn.CrossEntropyLoss() initial_lr = get_init_lr(learning_rate) optimizer = torch.optim.AdamW( [{'params': model.parameters(), 'initial_lr': initial_lr}], lr=initial_lr, weight_decay=weight_decay, ) optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) scheduler = lr_scheduler.OneCycleLR( optimizer, max_lr=learning_rate, steps_per_epoch=len(train_dataloader), epochs=max_epochs, Loading