Commit 6cb11538 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): upadate from main

parent 0271d826
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -104,6 +104,13 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        writer.add_custom_scalars({
            "general": {
                "accuracy": ["Multiline", ["train/accuracy", "test/accuracy"]],
                "false": ["Multiline", ["test/fn", "test/fp", "train/fn", "train/fp"]],
            },
            "test": {
                "false": ["Multiline", ["test/fn", "test/fp"]],
            },
            "train": {
                "false": ["Multiline", ["train/fn", "train/fp"]],
            },
        })
    else:
@@ -116,6 +123,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    test_size = dataset_size - train_size

    # 使用 random_split 函数拆分数据集
    num_workers = num_workers or min(os.cpu_count(), batch_size)
    train_dataset, test_dataset = random_split_dataset(
        full_dataset, train_size, test_size,
        trans_val=TRANSFORM2_VAL if full_dataset.__dims__ == 2 else TRANSFORM_VAL