Commit 02ad00a2 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): save the model at the last position

parent c52090ce
Loading
Loading
Loading
Loading
+4 −4
Original line number Diff line number Diff line
@@ -107,10 +107,6 @@ def train(dataset_dir: str, from_ckpt: Optional[str] = None, train_ratio: float
        logging.info(f'Epoch {epoch} loss: {epoch_loss:.4f}')
        writer.add_scalar('train/loss', epoch_loss, epoch)

        current_ckpt_file = os.path.join(_CKPT_DIR, f'monochrome-{epoch}.ckpt')
        torch.save(model.state_dict(), current_ckpt_file)
        logging.info(f'Saved to {current_ckpt_file!r}.')

        with torch.no_grad():
            train_correct = 0
            for i, (inputs, labels) in enumerate(tqdm(train_dataloader)):
@@ -139,3 +135,7 @@ def train(dataset_dir: str, from_ckpt: Optional[str] = None, train_ratio: float
            test_accuracy = test_correct / len(test_dataset)
            logging.info(f'Epoch {epoch} test accuracy: {test_accuracy:.4f}')
            writer.add_scalar('test/accuracy', test_accuracy, epoch)

        current_ckpt_file = os.path.join(_CKPT_DIR, f'monochrome-{epoch}.ckpt')
        torch.save(model.state_dict(), current_ckpt_file)
        logging.info(f'Saved to {current_ckpt_file!r}.')