Commit 773c57ff authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add better train preference && add export for new models

parent 0406b106
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -40,8 +40,8 @@ def export_one(output: str, feature_bins: int, ckpt: str, model_name: str):


# Table of known checkpoints, one tuple per checkpoint.
# NOTE(review): the fields look like (checkpoint filename, model name,
# feature bins), mirroring the `ckpt`, `model_name` and `feature_bins`
# parameters of export_one -- confirm against the code that consumes this list.
_KNOWN_CKPTS: List[Tuple[str, str, int]] = [
    ('monochrome-alexnet_plus-320.ckpt', 'alexnet', 256),
    ('monochrome-alexnet_plus-500.ckpt', 'alexnet', 256),
    ('monochrome-alexnet-480.ckpt', 'alexnet', 180),
    ('monochrome-resnet18-480.ckpt', 'resnet18', 180),
]


+30 −7
Original line number Diff line number Diff line
@@ -70,7 +70,7 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]:

def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None,
          train_ratio: float = 0.8, batch_size: int = 4, feature_bins: int = 180, fc: Optional[int] = 75,
          max_epochs: int = 500, learning_rate: float = 0.001, weight_decay: float = 1e-3,
          max_epochs: int = 500, learning_rate: float = 0.001, weight_decay: float = 1e-3, preference: float = 0.0,
          num_workers: Optional[int] = None, device: Optional[str] = None,
          save_per_epoch: int = 10, model_name: str = 'alexnet'):
    session_name = session_name or model_name
@@ -81,6 +81,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    writer.add_custom_scalars({
        "general": {
            "accuracy": ["Multiline", ["train/accuracy", "test/accuracy"]],
            "false": ["Multiline", ["train/fp", "train/fn", "test/fp", "test/fn"]],
        },
    })

@@ -91,7 +92,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    test_size = dataset_size - train_size

    # Split the dataset using the random_split function
    num_workers = num_workers or os.cpu_count()
    num_workers = num_workers or min(os.cpu_count(), batch_size)
    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                                  drop_last=True)
@@ -113,7 +114,11 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    loss_fn = nn.CrossEntropyLoss()
    if preference < 0:
        loss_weight = torch.as_tensor([torch.e, 1.0]) ** -preference
    else:
        loss_weight = torch.as_tensor([1.0, torch.e]) ** preference
    loss_fn = nn.CrossEntropyLoss(weight=loss_weight)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = lr_scheduler.OneCycleLR(
        optimizer, max_lr=learning_rate,
@@ -141,28 +146,46 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio

        with torch.no_grad():
            train_correct, train_total = 0, 0
            train_fp, train_fn = 0, 0
            for i, (inputs, labels) in enumerate(tqdm(train_dataloader)):
                inputs = inputs.float().to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                train_correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
                preds = torch.argmax(outputs, dim=1)
                train_correct += (preds == labels).sum().item()
                train_fp += (preds[labels == 0] == 1).sum().item()
                train_fn += (preds[labels == 1] == 0).sum().item()
                train_total += labels.shape[0]

            train_accuracy = train_correct / train_total
            logging.info(f'Epoch {epoch} train accuracy: {train_accuracy:.4f}')
            train_fp_p = train_fp / train_total
            train_fn_p = train_fn / train_total
            logging.info(f'Epoch {epoch}, train accuracy: {train_accuracy:.4f}, '
                         f'false positive: {train_fp_p:.4f}, false negative: {train_fn_p:.4f}')
            writer.add_scalar('train/accuracy', train_accuracy, epoch)
            writer.add_scalar('train/fp', train_fp_p, epoch)
            writer.add_scalar('train/fn', train_fn_p, epoch)

            test_correct, test_total = 0, 0
            test_fp, test_fn = 0, 0
            for i, (inputs, labels) in enumerate(tqdm(test_dataloader)):
                inputs = inputs.float().to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                test_correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
                preds = torch.argmax(outputs, dim=1)
                test_correct += (preds == labels).sum().item()
                test_fp += (preds[labels == 0] == 1).sum().item()
                test_fn += (preds[labels == 1] == 0).sum().item()
                test_total += labels.shape[0]

            test_accuracy = test_correct / test_total
            logging.info(f'Epoch {epoch} test accuracy: {test_accuracy:.4f}')
            test_fp_p = test_fp / test_total
            test_fn_p = test_fn / test_total
            logging.info(f'Epoch {epoch}, test accuracy: {test_accuracy:.4f}, '
                         f'false positive: {test_fp_p:.4f}, false negative: {test_fn_p:.4f}')
            writer.add_scalar('test/accuracy', test_accuracy, epoch)
            writer.add_scalar('test/fp', test_fp_p, epoch)
            writer.add_scalar('test/fn', test_fn_p, epoch)

        if epoch % save_per_epoch == 0:
            current_ckpt_file = os.path.join(_CKPT_DIR, f'monochrome-{session_name}-{epoch}.ckpt')