Commit 43dd6fb6 authored by dzy7e's avatar dzy7e
Browse files

fix hsv

parent 4ac936bf
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -36,9 +36,10 @@ class ImageDirectoryDataset(Dataset):

    def __getitem__(self, idx):
        file_path = self.samples[idx]
        image = Image.open(file_path).convert('HSV')
        image = Image.open(file_path)
        if self.transform:
            image = self.transform(image)
        image = image.convert('HSV')
        return image_encode(image, bins=self.bins, fc=self.fc, normalize=True), torch.tensor(self.label)


+1 −1
Original line number Diff line number Diff line
@@ -92,7 +92,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio

    # 使用 random_split 函数拆分数据集
    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)

    # Load previous epoch