Loading zoo/monochrome/loss.py 0 → 100644 +26 −0 Original line number Diff line number Diff line import torch from torch import nn from torch.nn import functional as F class FocalLoss(nn.Module): """ Based on https://discuss.pytorch.org/t/is-this-a-correct-implementation-for-focal-loss-in-pytorch/43327/8 """ def __init__(self, weight=None, gamma=2., reduction='mean'): nn.Module.__init__(self) self.weight = torch.as_tensor(weight).float() if weight else weight self.gamma = gamma self.reduction = reduction def forward(self, input_tensor, target_tensor): log_prob = F.log_softmax(input_tensor, dim=-1) prob = torch.exp(log_prob) return F.nll_loss( ((1 - prob) ** self.gamma) * log_prob, target_tensor, weight=self.weight, reduction=self.reduction ) zoo/monochrome/train_.py +2 −1 Original line number Diff line number Diff line Loading @@ -13,6 +13,7 @@ from tqdm.auto import tqdm from .alexnet import MonochromeAlexNet from .dataset import MonochromeDataset from .loss import FocalLoss from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 from .transformer import SigTransformer from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR Loading Loading @@ -118,7 +119,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio loss_weight = torch.as_tensor([torch.e, 1.0]) ** -preference else: loss_weight = torch.as_tensor([1.0, torch.e]) ** preference loss_fn = nn.CrossEntropyLoss(weight=loss_weight).to(device) loss_fn = FocalLoss(weight=loss_weight).to(device) optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) scheduler = lr_scheduler.OneCycleLR( optimizer, max_lr=learning_rate, Loading Loading
zoo/monochrome/loss.py 0 → 100644 +26 −0 Original line number Diff line number Diff line import torch from torch import nn from torch.nn import functional as F class FocalLoss(nn.Module): """ Based on https://discuss.pytorch.org/t/is-this-a-correct-implementation-for-focal-loss-in-pytorch/43327/8 """ def __init__(self, weight=None, gamma=2., reduction='mean'): nn.Module.__init__(self) self.weight = torch.as_tensor(weight).float() if weight else weight self.gamma = gamma self.reduction = reduction def forward(self, input_tensor, target_tensor): log_prob = F.log_softmax(input_tensor, dim=-1) prob = torch.exp(log_prob) return F.nll_loss( ((1 - prob) ** self.gamma) * log_prob, target_tensor, weight=self.weight, reduction=self.reduction )
zoo/monochrome/train_.py +2 −1 Original line number Diff line number Diff line Loading @@ -13,6 +13,7 @@ from tqdm.auto import tqdm from .alexnet import MonochromeAlexNet from .dataset import MonochromeDataset from .loss import FocalLoss from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 from .transformer import SigTransformer from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR Loading Loading @@ -118,7 +119,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio loss_weight = torch.as_tensor([torch.e, 1.0]) ** -preference else: loss_weight = torch.as_tensor([1.0, torch.e]) ** preference loss_fn = nn.CrossEntropyLoss(weight=loss_weight).to(device) loss_fn = FocalLoss(weight=loss_weight).to(device) optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) scheduler = lr_scheduler.OneCycleLR( optimizer, max_lr=learning_rate, Loading