Commit 9f3440c2 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): fix bug on cuda

parent c86b3686
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ class FocalLoss(nn.Module):

    def __init__(self, weight=None, gamma=2., reduction='mean'):
        nn.Module.__init__(self)
        self.weight = torch.as_tensor(weight).float() if weight else weight
        self.weight = torch.as_tensor(weight).float() if weight is not None else weight
        self.gamma = gamma
        self.reduction = reduction

+1 −1
Original line number Diff line number Diff line
@@ -119,7 +119,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        loss_weight = torch.as_tensor([torch.e, 1.0]) ** -preference
    else:
        loss_weight = torch.as_tensor([1.0, torch.e]) ** preference
    loss_fn = FocalLoss(weight=loss_weight).to(device)
    loss_fn = FocalLoss(weight=loss_weight.to(device)).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = lr_scheduler.OneCycleLR(
        optimizer, max_lr=learning_rate,