Loading zoo/monochrome/dataset.py +2 −1 Original line number Diff line number Diff line Loading @@ -36,9 +36,10 @@ class ImageDirectoryDataset(Dataset): def __getitem__(self, idx): file_path = self.samples[idx] image = Image.open(file_path).convert('HSV') image = Image.open(file_path) if self.transform: image = self.transform(image) image = image.convert('HSV') return image_encode(image, bins=self.bins, fc=self.fc, normalize=True), torch.tensor(self.label) Loading zoo/monochrome/train_.py +1 −1 Original line number Diff line number Diff line Loading @@ -92,7 +92,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio # 使用 random_split 函数拆分数据集 train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size]) train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True) test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers) # Load previous epoch Loading Loading
zoo/monochrome/dataset.py +2 −1 Original line number Diff line number Diff line Loading @@ -36,9 +36,10 @@ class ImageDirectoryDataset(Dataset): def __getitem__(self, idx): file_path = self.samples[idx] image = Image.open(file_path).convert('HSV') image = Image.open(file_path) if self.transform: image = self.transform(image) image = image.convert('HSV') return image_encode(image, bins=self.bins, fc=self.fc, normalize=True), torch.tensor(self.label) Loading
zoo/monochrome/train_.py +1 −1 Original line number Diff line number Diff line Loading @@ -92,7 +92,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio # 使用 random_split 函数拆分数据集 train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size]) train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True) test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers) # Load previous epoch Loading