Commit fff9d2c2 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): fix bug on dims

parent 6cb11538
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -28,12 +28,14 @@ _CKPT_DIR = os.path.join(_TRAIN_DIR, 'ckpts')
_CKPT_PATTERN = re.compile(r'^monochrome-(?P<name>[a-zA-Z\d_\-]+)-(?P<epoch>\d+)\.ckpt$')

_KNOWN_MODELS = {}
_KNOWN_DIMS = {}
_KNOWN_DATASETS = {}


def _register_model(cls: Type[nn.Module], *args, name=None, **kwargs):
    name = name or cls.__model_name__
    _KNOWN_MODELS[name] = partial(cls, *args, **kwargs)
    _KNOWN_DIMS[name] = getattr(cls, '__dims__', 1)


_register_model(MonochromeAlexNet)
@@ -94,7 +96,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    )

    session_name = session_name or model_name
    model_dims = getattr(_KNOWN_MODELS[model_name], '__dims__', 1)
    model_dims = _KNOWN_DIMS[model_name]
    _log_dir = os.path.join(_LOG_DIR, session_name)

    if accelerator.is_local_main_process:
@@ -126,7 +128,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    num_workers = num_workers or min(os.cpu_count(), batch_size)
    train_dataset, test_dataset = random_split_dataset(
        full_dataset, train_size, test_size,
        trans_val=TRANSFORM2_VAL if full_dataset.__dims__ == 2 else TRANSFORM_VAL
        trans_val=TRANSFORM2_VAL if model_dims == 2 else TRANSFORM_VAL
    )
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                                  drop_last=True)