Loading zoo/monochrome/train_.py +4 −1 Original line number Diff line number Diff line Loading @@ -9,6 +9,7 @@ from torch import nn from torch.optim import lr_scheduler from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from torchvision.ops import sigmoid_focal_loss from tqdm.auto import tqdm from .alexnet import MonochromeAlexNet Loading @@ -17,6 +18,7 @@ from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 from .transformer import SigTransformer from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR from accelerate import Accelerator _TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'monochrome') Loading Loading @@ -119,7 +121,8 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio #if torch.cuda.is_available(): # model = model.cuda() loss_fn = nn.CrossEntropyLoss() #loss_fn = nn.CrossEntropyLoss() loss_fn = lambda inputs, targets: sigmoid_focal_loss(inputs, targets, reduction='mean') optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) scheduler = lr_scheduler.OneCycleLR( optimizer, max_lr=learning_rate, Loading Loading
zoo/monochrome/train_.py +4 −1 Original line number Diff line number Diff line Loading @@ -9,6 +9,7 @@ from torch import nn from torch.optim import lr_scheduler from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from torchvision.ops import sigmoid_focal_loss from tqdm.auto import tqdm from .alexnet import MonochromeAlexNet Loading @@ -17,6 +18,7 @@ from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 from .transformer import SigTransformer from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR from accelerate import Accelerator _TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'monochrome') Loading Loading @@ -119,7 +121,8 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio #if torch.cuda.is_available(): # model = model.cuda() loss_fn = nn.CrossEntropyLoss() #loss_fn = nn.CrossEntropyLoss() loss_fn = lambda inputs, targets: sigmoid_focal_loss(inputs, targets, reduction='mean') optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) scheduler = lr_scheduler.OneCycleLR( optimizer, max_lr=learning_rate, Loading