Commit 434a7bb8 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): prepare to upload new model

parent 2f4f0770
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ from .train_ import _KNOWN_MODELS
from ..utils import GLOBAL_CONTEXT_SETTINGS
from ..utils import print_version as _origin_print_version

print_version = partial(_origin_print_version, 'zoo.lpips')
print_version = partial(_origin_print_version, 'zoo.monochrome')


@click.group(context_settings={**GLOBAL_CONTEXT_SETTINGS})
@@ -41,6 +41,7 @@ def export_one(output: str, feature_bins: int, ckpt: str, model_name: str):

_KNOWN_CKPTS: List[Tuple[str, str, int]] = [
    ('monochrome-alexnet_plus-320.ckpt', 'alexnet', 256),
    ('monochrome-alexnet_plus-500.ckpt', 'alexnet', 256),
]


+6 −0
Original line number Diff line number Diff line
@@ -36,6 +36,12 @@ class ImageDirectoryDataset(Dataset):

    def __getitem__(self, idx):
        file_path = self.samples[idx]
        # ATTENTION: DO NOT REMOVE THIS CONVERT, THIS IS IMPORTANT!!!
        # In torchvision.transforms.functional.pad, the RGB mode will be used when your input image
        # have 3 channels (no matter RGB, LAB or HSV), so the transformed image which actually passed into
        # model should be processed like this:
        # image = Image.fromarray(np.asarray(image.convert("HSV")), mode='RGB')
        # and then use `image_encode` to encode.
        image = Image.open(file_path).convert('HSV')
        if self.transform:
            image = self.transform(image)
+6 −12
Original line number Diff line number Diff line
@@ -71,7 +71,7 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]:
def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None,
          train_ratio: float = 0.8, batch_size: int = 4, feature_bins: int = 256, fc: Optional[int] = 100,
          max_epochs: int = 500, learning_rate: LRTyping = 0.001,
          weight_decay: float = 1e-4, num_workers: Optional[int] = None,
          weight_decay: float = 1e-2, num_workers: Optional[int] = None,
          device: Optional[str] = None, save_per_epoch: int = 10, model_name: str = 'alexnet'):
    session_name = session_name or model_name
    _log_dir = os.path.join(_LOG_DIR, session_name)
@@ -104,7 +104,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    previous_epoch = _ckpt_epoch(from_ckpt) or 0
    if from_ckpt:
        logging.info(f'Load checkpoint from {from_ckpt!r}.')
        model.load_state_dict(torch.load(from_ckpt))
        model.load_state_dict(torch.load(from_ckpt, map_location='cpu'))
    else:
        logging.info(f'No checkpoint found, new model will be used.')

@@ -142,11 +142,8 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        with torch.no_grad():
            train_correct = 0
            for i, (inputs, labels) in enumerate(tqdm(train_dataloader)):
                inputs = inputs.float()
                if torch.cuda.is_available():
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                inputs = inputs.float().to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                train_correct += (torch.argmax(outputs, dim=1) == labels).sum().item()

@@ -156,11 +153,8 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio

            test_correct = 0
            for i, (inputs, labels) in enumerate(tqdm(test_dataloader)):
                inputs = inputs.float()
                if torch.cuda.is_available():
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                inputs = inputs.float().to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                test_correct += (torch.argmax(outputs, dim=1) == labels).sum().item()