Commit c86b3686 authored by narugo1992
Browse files

dev(narugo): use focal loss

parent 7c9330b4
Loading
Loading
Loading
Loading

zoo/monochrome/loss.py

0 → 100644
+26 −0
Original line number Diff line number Diff line
import torch
from torch import nn

from torch.nn import functional as F


class FocalLoss(nn.Module):
    """
    Multi-class focal loss computed from raw logits.

    The per-element log-probability is scaled by ``(1 - p) ** gamma`` before
    being handed to ``F.nll_loss``, so well-classified samples (``p`` close
    to 1) contribute less to the loss. With ``gamma == 0`` this reduces to
    the (optionally class-weighted) cross-entropy loss.

    Based on https://discuss.pytorch.org/t/is-this-a-correct-implementation-for-focal-loss-in-pytorch/43327/8

    :param weight: Optional per-class rescaling weights (tensor or sequence
        of numbers, one entry per class). ``None`` disables class weighting.
    :param gamma: Focusing exponent applied to ``(1 - p)``.
    :param reduction: Reduction mode forwarded to ``F.nll_loss``
        (``'none'``, ``'mean'`` or ``'sum'``).
    """

    def __init__(self, weight=None, gamma=2., reduction='mean'):
        super().__init__()
        # Compare against None explicitly: the truth value of a tensor with
        # more than one element is ambiguous and raises a RuntimeError.
        if weight is not None:
            weight = torch.as_tensor(weight).float()
        # Registered as a buffer so that ``.to(device)`` / ``.cuda()`` move the
        # class weights together with the module. Non-persistent keeps the
        # module's ``state_dict`` empty, exactly as with a plain attribute.
        self.register_buffer('weight', weight, persistent=False)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input_tensor, target_tensor):
        """
        Compute the focal loss.

        :param input_tensor: Raw, unnormalized logits laid out the way
            ``F.nll_loss`` expects its input: ``(C,)``, ``(N, C)`` or
            ``(N, C, d1, ..., dk)``.
        :param target_tensor: Class indices accepted by ``F.nll_loss`` for
            the given input layout.
        :return: The loss tensor, reduced according to ``self.reduction``.
        """
        # ``F.nll_loss`` treats dim 1 as the class axis for batched input and
        # dim 0 for an unbatched ``(C,)`` input; normalize over the same axis.
        class_dim = 1 if input_tensor.dim() > 1 else 0
        log_prob = F.log_softmax(input_tensor, dim=class_dim)
        prob = torch.exp(log_prob)
        return F.nll_loss(
            ((1 - prob) ** self.gamma) * log_prob,
            target_tensor,
            weight=self.weight,
            reduction=self.reduction
        )
+2 −1
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@ from tqdm.auto import tqdm

from .alexnet import MonochromeAlexNet
from .dataset import MonochromeDataset
from .loss import FocalLoss
from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from .transformer import SigTransformer
from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR
@@ -118,7 +119,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        loss_weight = torch.as_tensor([torch.e, 1.0]) ** -preference
    else:
        loss_weight = torch.as_tensor([1.0, torch.e]) ** preference
    loss_fn = nn.CrossEntropyLoss(weight=loss_weight).to(device)
    loss_fn = FocalLoss(weight=loss_weight).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = lr_scheduler.OneCycleLR(
        optimizer, max_lr=learning_rate,