Commit e88c69d8 authored by dzy7e's avatar dzy7e
Browse files

update

parent 3e7a6001
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -142,7 +142,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = lr_scheduler.OneCycleLR(
        optimizer, max_lr=learning_rate,
        steps_per_epoch=len(train_dataloader), epochs=max_epochs,
        steps_per_epoch=len(train_dataloader)//accelerator.num_processes, epochs=max_epochs,
        pct_start=0.15, final_div_factor=20.
    )
    # model = torch.compile(model)