Commit 3b099aa7 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): fix a bug in loss

parent 3aa64352
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -148,7 +148,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
            loss = loss_fn(outputs, labels)
            accelerator.backward(loss)
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            running_loss += loss.item() * labels.shape[0]
            scheduler.step()

        epoch_loss = running_loss / train_total