Loading zoo/monochrome/__main__.py +2 −1 Original line number Diff line number Diff line Loading @@ -12,7 +12,7 @@ from .train_ import _KNOWN_MODELS from ..utils import GLOBAL_CONTEXT_SETTINGS from ..utils import print_version as _origin_print_version print_version = partial(_origin_print_version, 'zoo.lpips') print_version = partial(_origin_print_version, 'zoo.monochrome') @click.group(context_settings={**GLOBAL_CONTEXT_SETTINGS}) Loading Loading @@ -41,6 +41,7 @@ def export_one(output: str, feature_bins: int, ckpt: str, model_name: str): _KNOWN_CKPTS: List[Tuple[str, str, int]] = [ ('monochrome-alexnet_plus-320.ckpt', 'alexnet', 256), ('monochrome-alexnet_plus-500.ckpt', 'alexnet', 256), ] Loading zoo/monochrome/dataset.py +6 −0 Original line number Diff line number Diff line Loading @@ -36,6 +36,12 @@ class ImageDirectoryDataset(Dataset): def __getitem__(self, idx): file_path = self.samples[idx] # ATTENTION: DO NOT REMOVE THIS CONVERT, THIS IS IMPORTANT!!! # In torchvision.transforms.functional.pad, the RGB mode will be used when your input image # have 3 channels (no matter RGB, LAB or HSV), so the transformed image which actually passed into # model should be processed like this: # image = Image.fromarray(np.asarray(image.convert("HSV")), mode='RGB') # and then use `image_encode` to encode. image = Image.open(file_path).convert('HSV') if self.transform: image = self.transform(image) Loading zoo/monochrome/train_.py +6 −12 Original line number Diff line number Diff line Loading @@ -71,7 +71,7 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]: def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None, train_ratio: float = 0.8, batch_size: int = 4, feature_bins: int = 256, fc: Optional[int] = 100, max_epochs: int = 500, learning_rate: LRTyping = 0.001, weight_decay: float = 1e-4, num_workers: Optional[int] = None, weight_decay: float = 1e-2, num_workers: Optional[int] = None, device: Optional[str] = None, save_per_epoch: int = 10, model_name: str = 'alexnet'): session_name = session_name or model_name _log_dir = os.path.join(_LOG_DIR, session_name) Loading Loading @@ -104,7 +104,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio previous_epoch = _ckpt_epoch(from_ckpt) or 0 if from_ckpt: logging.info(f'Load checkpoint from {from_ckpt!r}.') model.load_state_dict(torch.load(from_ckpt)) model.load_state_dict(torch.load(from_ckpt, map_location='cpu')) else: logging.info(f'No checkpoint found, new model will be used.') Loading Loading @@ -142,11 +142,8 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio with torch.no_grad(): train_correct = 0 for i, (inputs, labels) in enumerate(tqdm(train_dataloader)): inputs = inputs.float() if torch.cuda.is_available(): inputs = inputs.cuda() labels = labels.cuda() inputs = inputs.float().to(device) labels = labels.to(device) outputs = model(inputs) train_correct += (torch.argmax(outputs, dim=1) == labels).sum().item() Loading @@ -156,11 +153,8 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio test_correct = 0 for i, (inputs, labels) in enumerate(tqdm(test_dataloader)): inputs = inputs.float() if torch.cuda.is_available(): inputs = inputs.cuda() labels = labels.cuda() inputs = inputs.float().to(device) labels = labels.to(device) outputs = model(inputs) test_correct += (torch.argmax(outputs, dim=1) == labels).sum().item() Loading Loading
zoo/monochrome/__main__.py +2 −1 Original line number Diff line number Diff line Loading @@ -12,7 +12,7 @@ from .train_ import _KNOWN_MODELS from ..utils import GLOBAL_CONTEXT_SETTINGS from ..utils import print_version as _origin_print_version print_version = partial(_origin_print_version, 'zoo.lpips') print_version = partial(_origin_print_version, 'zoo.monochrome') @click.group(context_settings={**GLOBAL_CONTEXT_SETTINGS}) Loading Loading @@ -41,6 +41,7 @@ def export_one(output: str, feature_bins: int, ckpt: str, model_name: str): _KNOWN_CKPTS: List[Tuple[str, str, int]] = [ ('monochrome-alexnet_plus-320.ckpt', 'alexnet', 256), ('monochrome-alexnet_plus-500.ckpt', 'alexnet', 256), ] Loading
zoo/monochrome/dataset.py +6 −0 Original line number Diff line number Diff line Loading @@ -36,6 +36,12 @@ class ImageDirectoryDataset(Dataset): def __getitem__(self, idx): file_path = self.samples[idx] # ATTENTION: DO NOT REMOVE THIS CONVERT, THIS IS IMPORTANT!!! # In torchvision.transforms.functional.pad, the RGB mode will be used when your input image # have 3 channels (no matter RGB, LAB or HSV), so the transformed image which actually passed into # model should be processed like this: # image = Image.fromarray(np.asarray(image.convert("HSV")), mode='RGB') # and then use `image_encode` to encode. image = Image.open(file_path).convert('HSV') if self.transform: image = self.transform(image) Loading
zoo/monochrome/train_.py +6 −12 Original line number Diff line number Diff line Loading @@ -71,7 +71,7 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]: def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None, train_ratio: float = 0.8, batch_size: int = 4, feature_bins: int = 256, fc: Optional[int] = 100, max_epochs: int = 500, learning_rate: LRTyping = 0.001, weight_decay: float = 1e-4, num_workers: Optional[int] = None, weight_decay: float = 1e-2, num_workers: Optional[int] = None, device: Optional[str] = None, save_per_epoch: int = 10, model_name: str = 'alexnet'): session_name = session_name or model_name _log_dir = os.path.join(_LOG_DIR, session_name) Loading Loading @@ -104,7 +104,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio previous_epoch = _ckpt_epoch(from_ckpt) or 0 if from_ckpt: logging.info(f'Load checkpoint from {from_ckpt!r}.') model.load_state_dict(torch.load(from_ckpt)) model.load_state_dict(torch.load(from_ckpt, map_location='cpu')) else: logging.info(f'No checkpoint found, new model will be used.') Loading Loading @@ -142,11 +142,8 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio with torch.no_grad(): train_correct = 0 for i, (inputs, labels) in enumerate(tqdm(train_dataloader)): inputs = inputs.float() if torch.cuda.is_available(): inputs = inputs.cuda() labels = labels.cuda() inputs = inputs.float().to(device) labels = labels.to(device) outputs = model(inputs) train_correct += (torch.argmax(outputs, dim=1) == labels).sum().item() Loading @@ -156,11 +153,8 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio test_correct = 0 for i, (inputs, labels) in enumerate(tqdm(test_dataloader)): inputs = inputs.float() if torch.cuda.is_available(): inputs = inputs.cuda() labels = labels.cuda() inputs = inputs.float().to(device) labels = labels.to(device) outputs = model(inputs) test_correct += (torch.argmax(outputs, dim=1) == labels).sum().item() Loading