Commit 153a2d8a authored by dzy7e's avatar dzy7e
Browse files

update

parent a6950e04
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -177,6 +177,8 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
                    mean_loss = running_loss/train_pos_total
                    if writer:
                        writer.add_scalar('train/loss', mean_loss, epoch*num_iter + i)
                    running_loss = 0.
                    train_pos_total = 0

                if (i+1)%log_iter == 0:
                    pred_t = torch.cat(pred_list).to(accelerator.device)
@@ -193,8 +195,6 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio

                        pred_list.clear()
                        gt_list.clear()
                        running_loss = 0.
                        train_pos_total = 0

        model.eval()
        if epoch%eval_epoch == 0: