Commit 2f1c86ca authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): optimize training

parent 2a39b318
Loading
Loading
Loading
Loading
+8 −6
Original line number Diff line number Diff line
import os
from typing import Optional

import torch
from PIL import Image
@@ -13,16 +14,17 @@ TRANSFORM = transforms.Compose([
    transforms.RandomRotation((-180, 180)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(0.25, 0.25, 0.15, 0.3),
    transforms.ColorJitter(0.10, 0.10, 0.10, 0.10),
    transforms.Resize(450),
])


class ImageDirectoryDataset(Dataset):
    def __init__(self, root_dir, label: int = 1, bins: int = 200, transform=TRANSFORM):
    def __init__(self, root_dir, label: int = 1, bins: int = 200, fc: Optional[int] = 50, transform=TRANSFORM):
        self.root_dir = root_dir
        self.label = label
        self.bins = bins
        self.fc = fc
        self.transform = transform
        self.samples = []
        for file_name in os.listdir(root_dir):
@@ -37,13 +39,13 @@ class ImageDirectoryDataset(Dataset):
        image = Image.open(file_path).convert('HSV')
        if self.transform:
            image = self.transform(image)
        return image_encode(image, bins=self.bins, normalize=True), torch.tensor(self.label)
        return image_encode(image, bins=self.bins, fc=self.fc, normalize=True), torch.tensor(self.label)


class MonochromeDataset(Dataset):
    def __init__(self, root_dir: str, bins: int = 200, transform=TRANSFORM):
        self.monochrome = ImageDirectoryDataset(os.path.join(root_dir, 'monochrome'), 1, bins, transform)
        self.normal = ImageDirectoryDataset(os.path.join(root_dir, 'normal'), 0, bins, transform)
    def __init__(self, root_dir: str, bins: int = 200, fc: Optional[int] = 50, transform=TRANSFORM):
        self.monochrome = ImageDirectoryDataset(os.path.join(root_dir, 'monochrome'), 1, bins, fc, transform)
        self.normal = ImageDirectoryDataset(os.path.join(root_dir, 'normal'), 0, bins, fc, transform)

    def __len__(self):
        return len(self.monochrome) + len(self.normal)
+8 −6
Original line number Diff line number Diff line
@@ -67,10 +67,12 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]:
        return None


def train(dataset_dir: str, from_ckpt: Optional[str] = None, train_ratio: float = 0.8,
          batch_size: int = 4, feature_bins: int = 400, max_epochs: int = 500, learning_rate: float = 0.001,
def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None,
          train_ratio: float = 0.8, batch_size: int = 4, feature_bins: int = 256, fc: Optional[int] = 100,
          max_epochs: int = 500, learning_rate: float = 0.001,
          save_per_epoch: int = 10, model_name: str = 'alexnet'):
    _log_dir = os.path.join(_LOG_DIR, model_name)
    session_name = session_name or model_name
    _log_dir = os.path.join(_LOG_DIR, session_name)
    os.makedirs(_log_dir, exist_ok=True)
    os.makedirs(_CKPT_DIR, exist_ok=True)
    writer = SummaryWriter(_log_dir)
@@ -81,7 +83,7 @@ def train(dataset_dir: str, from_ckpt: Optional[str] = None, train_ratio: float
    })

    # Initialize dataset
    full_dataset = MonochromeDataset(dataset_dir, bins=feature_bins)
    full_dataset = MonochromeDataset(dataset_dir, bins=feature_bins, fc=fc)
    dataset_size = len(full_dataset)
    train_size = int(train_ratio * dataset_size)
    test_size = dataset_size - train_size
@@ -95,7 +97,7 @@ def train(dataset_dir: str, from_ckpt: Optional[str] = None, train_ratio: float
    model = _KNOWN_MODELS[model_name]().float()
    # model = MonochromeAlexNet().float()
    if from_ckpt is None:
        from_ckpt = _find_latest_ckpt(model_name)
        from_ckpt = _find_latest_ckpt(session_name)
    previous_epoch = _ckpt_epoch(from_ckpt) or 0
    if from_ckpt:
        logging.info(f'Load checkpoint from {from_ckpt!r}.')
@@ -159,6 +161,6 @@ def train(dataset_dir: str, from_ckpt: Optional[str] = None, train_ratio: float
            writer.add_scalar('test/accuracy', test_accuracy, epoch)

        if epoch % save_per_epoch == 0:
            current_ckpt_file = os.path.join(_CKPT_DIR, f'monochrome-{model_name}-{epoch}.ckpt')
            current_ckpt_file = os.path.join(_CKPT_DIR, f'monochrome-{session_name}-{epoch}.ckpt')
            torch.save(model.state_dict(), current_ckpt_file)
            logging.info(f'Saved to {current_ckpt_file!r}.')