Commit a3922933 authored by dzy7e's avatar dzy7e
Browse files

focal loss

parent eb3e7946
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ from torch import nn
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.ops import sigmoid_focal_loss
from tqdm.auto import tqdm

from .alexnet import MonochromeAlexNet
@@ -17,6 +18,7 @@ from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from .transformer import SigTransformer
from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR


from accelerate import Accelerator

_TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'monochrome')
@@ -119,7 +121,8 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    #if torch.cuda.is_available():
    #    model = model.cuda()

    loss_fn = nn.CrossEntropyLoss()
    #loss_fn = nn.CrossEntropyLoss()
    loss_fn = lambda inputs, targets: sigmoid_focal_loss(inputs, targets, reduction='mean')
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = lr_scheduler.OneCycleLR(
        optimizer, max_lr=learning_rate,