Commit a08e43cf authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): update some details

parent f4397a1a
Loading
Loading
Loading
Loading
+16 −13
Original line number Diff line number Diff line
@@ -70,8 +70,8 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]:

def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None,
          train_ratio: float = 0.8, batch_size: int = 4, feature_bins: int = 256, fc: Optional[int] = 100,
          max_epochs: int = 500, learning_rate: LRTyping = 0.001,
          save_per_epoch: int = 10, model_name: str = 'alexnet'):
          max_epochs: int = 500, learning_rate: LRTyping = 0.001, num_workers: Optional[int] = None,
          device: Optional[str] = None, save_per_epoch: int = 10, model_name: str = 'alexnet'):
    session_name = session_name or model_name
    _log_dir = os.path.join(_LOG_DIR, session_name)
    os.makedirs(_log_dir, exist_ok=True)
@@ -90,9 +90,10 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    test_size = dataset_size - train_size

    # 使用 random_split 函数拆分数据集
    num_workers = num_workers or os.cpu_count()
    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)

    # Load previous epoch
    model = _KNOWN_MODELS[model_name]().float()
@@ -107,22 +108,23 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        logging.info(f'No checkpoint found, new model will be used.')

    # Try use cude
    if torch.cuda.is_available():
        model = model.cuda()
    if not device:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    loss_fn = nn.CrossEntropyLoss()
    initial_lr = get_init_lr(learning_rate)
    optimizer = torch.optim.Adam([{'params': model.parameters(), 'initial_lr': initial_lr}], lr=initial_lr)
    optimizer = torch.optim.AdamW(
        [{'params': model.parameters(), 'initial_lr': initial_lr}],
        lr=initial_lr, weight_decay=1e-2,
    )
    scheduler = get_dynamic_lr_scheduler(optimizer, lr=learning_rate, last_epoch=previous_epoch)

    for epoch in tqdm(range(previous_epoch + 1, max_epochs + 1)):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(tqdm(train_dataloader)):
            inputs = inputs.float()
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                labels = labels.cuda()

            inputs = inputs.float().to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
@@ -131,7 +133,8 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
            running_loss += loss.item() * inputs.size(0)

        epoch_loss = running_loss / len(train_dataset)
        logging.info(f'Epoch {epoch} loss: {epoch_loss:.4f}, with learning rate: {scheduler.get_last_lr()[0]:.6f}')
        logging.info(f'Epoch [{epoch}/{max_epochs + 1}] loss: {epoch_loss:.4f}, '
                     f'with learning rate: {scheduler.get_last_lr()[0]:.6f}')
        scheduler.step()
        writer.add_scalar('train/loss', epoch_loss, epoch)