Commit d122553a authored by dzy7e's avatar dzy7e
Browse files

lr_scheduler

parent dc8924ac
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -133,10 +133,11 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            scheduler.step()

        epoch_loss = running_loss / len(train_dataset)
        logging.info(f'Epoch [{epoch}/{max_epochs+1}] loss: {epoch_loss:.4f}, with learning rate: {scheduler.get_last_lr()[0]:.6f}')
        scheduler.step()
        #scheduler.step()
        writer.add_scalar('train/loss', epoch_loss, epoch)

        with torch.no_grad():