Commit b8fc34f8 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add learning_rate parameter

parent 1cb3a41e
Loading
Loading
Loading
Loading
+2 −6
Original line number Diff line number Diff line
@@ -49,7 +49,7 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]:


def train(dataset_dir: str, from_ckpt: Optional[str] = None, train_ratio: float = 0.8,
          batch_size: int = 4, feature_bins: int = 400, max_epochs: int = 500):
          batch_size: int = 4, feature_bins: int = 400, max_epochs: int = 500, learning_rate: float = 0.001):
    os.makedirs(_LOG_DIR, exist_ok=True)
    os.makedirs(_CKPT_DIR, exist_ok=True)
    writer = SummaryWriter(_LOG_DIR)
@@ -81,11 +81,7 @@ def train(dataset_dir: str, from_ckpt: Optional[str] = None, train_ratio: float
        model = model.cuda()

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.001,
        betas=(0.9, 0.999), eps=1e-08, weight_decay=0,
        amsgrad=False
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(previous_epoch + 1, max_epochs + 1):
        running_loss = 0.0