Commit 06e26078 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): register weight as buffer

parent 0fab8927
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -11,7 +11,9 @@ class FocalLoss(nn.Module):

    def __init__(self, weight=None, gamma=2., reduction='mean'):
        nn.Module.__init__(self)
        self.weight = torch.as_tensor(weight).float() if weight is not None else weight
        weight = torch.as_tensor(weight).float() if weight is not None else weight
        self.register_buffer('weight', weight)

        self.gamma = gamma
        self.reduction = reduction