Loading zoo/monochrome/train_.py +4 −2 Original line number Diff line number Diff line Loading @@ -28,12 +28,14 @@ _CKPT_DIR = os.path.join(_TRAIN_DIR, 'ckpts') _CKPT_PATTERN = re.compile(r'^monochrome-(?P<name>[a-zA-Z\d_\-]+)-(?P<epoch>\d+)\.ckpt$') _KNOWN_MODELS = {} _KNOWN_DIMS = {} _KNOWN_DATASETS = {} def _register_model(cls: Type[nn.Module], *args, name=None, **kwargs): name = name or cls.__model_name__ _KNOWN_MODELS[name] = partial(cls, *args, **kwargs) _KNOWN_DIMS[name] = getattr(cls, '__dims__', 1) _register_model(MonochromeAlexNet) Loading Loading @@ -94,7 +96,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio ) session_name = session_name or model_name model_dims = getattr(_KNOWN_MODELS[model_name], '__dims__', 1) model_dims = _KNOWN_DIMS[model_name] _log_dir = os.path.join(_LOG_DIR, session_name) if accelerator.is_local_main_process: Loading Loading @@ -126,7 +128,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio num_workers = num_workers or min(os.cpu_count(), batch_size) train_dataset, test_dataset = random_split_dataset( full_dataset, train_size, test_size, trans_val=TRANSFORM2_VAL if full_dataset.__dims__ == 2 else TRANSFORM_VAL trans_val=TRANSFORM2_VAL if model_dims == 2 else TRANSFORM_VAL ) train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True) Loading Loading
zoo/monochrome/train_.py +4 −2 Original line number Diff line number Diff line Loading @@ -28,12 +28,14 @@ _CKPT_DIR = os.path.join(_TRAIN_DIR, 'ckpts') _CKPT_PATTERN = re.compile(r'^monochrome-(?P<name>[a-zA-Z\d_\-]+)-(?P<epoch>\d+)\.ckpt$') _KNOWN_MODELS = {} _KNOWN_DIMS = {} _KNOWN_DATASETS = {} def _register_model(cls: Type[nn.Module], *args, name=None, **kwargs): name = name or cls.__model_name__ _KNOWN_MODELS[name] = partial(cls, *args, **kwargs) _KNOWN_DIMS[name] = getattr(cls, '__dims__', 1) _register_model(MonochromeAlexNet) Loading Loading @@ -94,7 +96,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio ) session_name = session_name or model_name model_dims = getattr(_KNOWN_MODELS[model_name], '__dims__', 1) model_dims = _KNOWN_DIMS[model_name] _log_dir = os.path.join(_LOG_DIR, session_name) if accelerator.is_local_main_process: Loading Loading @@ -126,7 +128,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio num_workers = num_workers or min(os.cpu_count(), batch_size) train_dataset, test_dataset = random_split_dataset( full_dataset, train_size, test_size, trans_val=TRANSFORM2_VAL if full_dataset.__dims__ == 2 else TRANSFORM_VAL trans_val=TRANSFORM2_VAL if model_dims == 2 else TRANSFORM_VAL ) train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True) Loading