Commit f4dab666 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): save every 10 epochs

parent d771f620
Loading
Loading
Loading
Loading
+6 −4
Original line number Diff line number Diff line
@@ -49,7 +49,8 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]:


def train(dataset_dir: str, from_ckpt: Optional[str] = None, train_ratio: float = 0.8,
          batch_size: int = 4, feature_bins: int = 400, max_epochs: int = 500, learning_rate: float = 0.001):
          batch_size: int = 4, feature_bins: int = 400, max_epochs: int = 500, learning_rate: float = 0.001,
          save_per_epoch: int = 10):
    os.makedirs(_LOG_DIR, exist_ok=True)
    os.makedirs(_CKPT_DIR, exist_ok=True)
    writer = SummaryWriter(_LOG_DIR)
@@ -136,6 +137,7 @@ def train(dataset_dir: str, from_ckpt: Optional[str] = None, train_ratio: float
            logging.info(f'Epoch {epoch} test accuracy: {test_accuracy:.4f}')
            writer.add_scalar('test/accuracy', test_accuracy, epoch)

        if epoch % save_per_epoch == 0:
            current_ckpt_file = os.path.join(_CKPT_DIR, f'monochrome-{epoch}.ckpt')
            torch.save(model.state_dict(), current_ckpt_file)
            logging.info(f'Saved to {current_ckpt_file!r}.')