Commit d771f620 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add tqdm to epochs

parent 02ad00a2
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -88,7 +88,7 @@ def train(dataset_dir: str, from_ckpt: Optional[str] = None, train_ratio: float
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(previous_epoch + 1, max_epochs + 1):
    for epoch in tqdm(range(previous_epoch + 1, max_epochs + 1)):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(tqdm(train_dataloader)):
            inputs = inputs.float()