Commit 6bb9bbc6 authored by dzy7e's avatar dzy7e
Browse files

focal loss

parent a3922933
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -121,8 +121,8 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    #if torch.cuda.is_available():
    #    model = model.cuda()

    #loss_fn = nn.CrossEntropyLoss()
    loss_fn = lambda inputs, targets: sigmoid_focal_loss(inputs, targets, reduction='mean')
    loss_fn = nn.CrossEntropyLoss()
    #loss_fn = lambda inputs, targets: sigmoid_focal_loss(inputs, targets, reduction='mean')
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = lr_scheduler.OneCycleLR(
        optimizer, max_lr=learning_rate,