Commit b4f06847 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): save the training code

parent c77de73b
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -52,3 +52,6 @@ dataset:
	if [ ! -d ${DATASET_DIR}/chafen_arknights ]; then \
		git clone https://${HF_NARUGO_USERNAME}:${HF_NARUGO_PASSWORD}@huggingface.co/datasets/deepghs/chafen_arknights.git ${DATASET_DIR}/chafen_arknights; \
	fi
	if [ ! -d ${DATASET_DIR}/monochrome_danbooru ]; then \
		git clone https://${HF_NARUGO_USERNAME}:${HF_NARUGO_PASSWORD}@huggingface.co/datasets/deepghs/monochrome_danbooru.git ${DATASET_DIR}/monochrome_danbooru; \
	fi
 No newline at end of file
+0 −0

Empty file added.

+39 −0
Original line number Diff line number Diff line
import torch
import torch.nn as nn


class MonochromeAlexNet(nn.Module):
    """
    AlexNet-shaped classifier built from 1D convolutions.

    Five ``Conv1d`` stages (each followed by ReLU, three of them by max
    pooling) extract features, the result is adaptively average-pooled to a
    length of 6, flattened, and passed through a three-layer fully connected
    head that produces ``num_classes`` logits.

    :param input_channels: Number of channels of the input sequence.
    :param num_classes: Number of output logits.
    """

    def __init__(self, input_channels: int = 3, num_classes=2):
        super().__init__()

        # The position of every layer inside the Sequential containers defines
        # the state-dict keys (``features.0``, ``features.3``, ...), so the
        # order below must stay stable for checkpoints to keep loading.
        feature_layers = [
            nn.Conv1d(input_channels, 96, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=3, stride=2),
            nn.Conv1d(96, 256, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=3, stride=2),
            nn.Conv1d(256, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(384, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(384, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*feature_layers)

        # Pooling to a fixed length makes the classifier input size
        # (256 channels * 6 positions) independent of the sequence length.
        self.avgpool = nn.AdaptiveAvgPool1d(6)

        classifier_layers = [
            nn.Dropout(),
            nn.Linear(256 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*classifier_layers)

    def forward(self, x):
        """
        Compute class logits for a batch of sequences.

        :param x: Batched input of shape ``(N, input_channels, L)``.
        :return: Tensor of shape ``(N, num_classes)``.
        """
        feature_map = self.features(x)
        pooled = self.avgpool(feature_map)
        # Keep the batch axis, merge channels and positions into one vector.
        return self.classifier(torch.flatten(pooled, 1))
+55 −0
Original line number Diff line number Diff line
import os

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import transforms

from .encode import image_encode

# Default augmentation pipeline; the steps run in the order listed.
_AUGMENTATION_STEPS = [
    # Upscale/downscale first so the crop below works at a known resolution.
    transforms.Resize(size=900),
    transforms.RandomCrop(size=800, padding=150, pad_if_needed=True, padding_mode='reflect'),
    # Full-circle rotation plus both flips: orientation carries no label information here.
    transforms.RandomRotation(degrees=(-180, 180)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.15, hue=0.3),
    # Shrink again to reduce the cost of the downstream encoding.
    transforms.Resize(size=450),
]
TRANSFORM = transforms.Compose(_AUGMENTATION_STEPS)


class ImageDirectoryDataset(Dataset):
    """
    Dataset that serves every file of one directory under a single label.

    Each item is the pair ``(encoding, label)`` where ``encoding`` is the
    result of :func:`image_encode` (called with ``normalize=True``) on the
    optionally augmented image, and ``label`` is ``torch.tensor(label)``.

    :param root_dir: Directory whose files are used as samples.
    :param label: Integer label attached to every sample.
    :param bins: Number of histogram bins forwarded to :func:`image_encode`.
    :param transform: Callable applied to the PIL image before encoding;
        a falsy value disables augmentation.
    """

    def __init__(self, root_dir, label: int = 1, bins: int = 200, transform=TRANSFORM):
        self.root_dir = root_dir
        self.label = label
        self.bins = bins
        self.transform = transform

        # os.listdir() returns entries in arbitrary order, so sort them to keep
        # the index -> file mapping reproducible between runs and machines.
        # Sub-directories are skipped because they cannot be opened as images.
        candidates = (os.path.join(root_dir, file_name) for file_name in sorted(os.listdir(root_dir)))
        self.samples = [file_path for file_path in candidates if os.path.isfile(file_path)]

    def __len__(self):
        """Return the number of sample files found in ``root_dir``."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Load, augment and encode the sample at position ``idx``.

        :param idx: Index into the sorted sample list.
        :return: Tuple ``(encoding, label_tensor)``.
        """
        file_path = self.samples[idx]
        # convert() returns an independent, fully loaded image, so the file
        # handle can be released right away instead of waiting for the GC.
        # NOTE(review): the augmentations run on an HSV-mode image while
        # image_encode() asks load_image() for mode='RGB' -- confirm that this
        # early HSV conversion is intended.
        with Image.open(file_path) as source:
            image = source.convert('HSV')
        if self.transform:
            image = self.transform(image)
        encoding = image_encode(image, bins=self.bins, normalize=True)
        return encoding, torch.tensor(self.label)


class MonochromeDataset(Dataset):
    """
    Two-class dataset made of a ``monochrome`` and a ``normal`` sub-directory.

    Samples from ``<root_dir>/monochrome`` carry label 1 and come first in
    the index space; samples from ``<root_dir>/normal`` carry label 0 and
    follow directly after them.

    :param root_dir: Directory containing the two class sub-directories.
    :param bins: Number of histogram bins passed on to both sub-datasets.
    :param transform: Augmentation passed on to both sub-datasets.
    """

    def __init__(self, root_dir: str, bins: int = 200, transform=TRANSFORM):
        monochrome_dir = os.path.join(root_dir, 'monochrome')
        normal_dir = os.path.join(root_dir, 'normal')
        self.monochrome = ImageDirectoryDataset(monochrome_dir, 1, bins, transform)
        self.normal = ImageDirectoryDataset(normal_dir, 0, bins, transform)

    def __len__(self):
        """Return the combined number of samples of both classes."""
        return sum(map(len, (self.monochrome, self.normal)))

    def __getitem__(self, idx):
        """
        Fetch a sample, dispatching to the sub-dataset that owns ``idx``.

        :param idx: Index into the concatenation monochrome + normal.
        :return: Whatever the owning sub-dataset returns for that position.
        """
        boundary = len(self.monochrome)
        if idx >= boundary:
            # Indices past the monochrome block are re-based into ``normal``.
            return self.normal[idx - boundary]
        return self.monochrome[idx]
+52 −0
Original line number Diff line number Diff line
from typing import Optional

import numpy as np
import torch
from PIL import ImageFilter
from scipy import signal
from torchvision.transforms.functional import to_tensor

from imgutils.data import load_image, ImageTyping


def np_hist(x, a_min: float = 0.0, a_max: float = 1.0, bins: int = 200):
    """
    Compute a normalized histogram of ``x`` over ``[a_min, a_max]``.

    :param x: Array-like of values (anything ``np.asarray`` accepts).
    :param a_min: Lower edge of the first bin.
    :param a_max: Upper edge of the last bin.
    :param bins: Number of equally wide bins.
    :return: 1-D float64 tensor with ``bins`` entries. The entries sum to 1
        when at least one value falls inside the range; otherwise every
        entry is 0.
    """
    x = np.asarray(x)
    # The edges are deliberately built with torch.linspace (float32) so that
    # bin boundaries stay bit-identical to previously produced encodings.
    edges = torch.linspace(a_min, a_max, bins + 1).numpy()
    cnt, _ = np.histogram(x, bins=edges)

    # An empty input, or one lying entirely outside the range, has a total of
    # zero; dividing by it would turn the whole histogram into NaN.
    total = cnt.sum()
    return torch.from_numpy(cnt / (total if total else 1))


def butterworth_filter(r, fc):
    """
    Smooth a 1-D signal with a zero-phase 5th-order low-pass Butterworth filter.

    :param r: 1-D sequence of samples.
    :param fc: Cut-off frequency, expressed relative to ``len(r) / 2``.
    :return: Float tensor of the same length as ``r`` with the filtered
        values clipped into ``[0.0, 1.0]``.
    """
    # Express the cut-off as a fraction of half the signal length, which is
    # the normalized form signal.butter() expects.
    half_length = len(r) / 2
    numerator, denominator = signal.butter(5, fc / half_length, 'low')

    # filtfilt runs the filter forwards and backwards, so no phase shift is
    # introduced; the filter may overshoot, hence the clipping afterwards.
    smoothed = signal.filtfilt(numerator, denominator, r)
    clipped = np.clip(smoothed, 0.0, 1.0)

    # Copy to guarantee a freshly allocated buffer for torch.from_numpy.
    return torch.from_numpy(clipped.copy())


def image_encode(image: ImageTyping, bins: int = 200, mf: Optional[int] = 5,
                 maxpixels: int = 20000, fc: Optional[int] = 30, normalize: bool = False):
    """
    Encode an image as three per-channel HSV histograms.

    The image is loaded via ``load_image(..., mode='RGB')``, shrunk when it
    has more than ``maxpixels`` pixels, optionally median-filtered, converted
    to HSV, and each of the three channels is reduced to a histogram with
    ``bins`` bins. The histograms can additionally be low-pass filtered and
    standardized.

    :param image: Image to encode, in any form ``load_image`` accepts.
    :param bins: Number of histogram bins per channel.
    :param mf: Size of the median filter; ``None`` skips the filter.
    :param maxpixels: Pixel budget; larger images are downscaled (keeping the
        aspect ratio) to roughly this many pixels.
    :param fc: Cut-off passed to :func:`butterworth_filter`; ``None`` skips
        the smoothing.
    :param normalize: When true, standardize every channel to zero mean and
        unit standard deviation.
    :return: Tensor of shape ``(3, bins)``.
    """
    image = load_image(image, mode='RGB')
    pixel_count = image.width * image.height
    if pixel_count > maxpixels:
        r = (pixel_count / maxpixels) ** 0.5
        # Clamp to 1 pixel: for a very elongated image the short edge would
        # otherwise round down to 0 and leave nothing to build a histogram from.
        new_width = max(1, int(round(image.width / r)))
        new_height = max(1, int(round(image.height / r)))
        image = image.resize((new_width, new_height))

    if mf is not None:
        image = image.filter(ImageFilter.MedianFilter(mf))
    image = image.convert('HSV')

    # One histogram per channel; np_hist() bins over [0, 1] by default, so the
    # tensor values are presumably scaled into that range -- TODO confirm for
    # HSV-mode input of to_tensor().
    data = to_tensor(image)
    channels = [np_hist(data[i], bins=bins) for i in range(3)]
    if fc is not None:
        channels = [butterworth_filter(ch, fc) for ch in channels]

    dist = torch.stack(channels)
    assert dist.shape == (3, bins)

    if normalize:
        mean = torch.mean(dist, dim=1, keepdim=True)
        std = torch.std(dist, dim=1, keepdim=True)
        # A channel whose histogram is constant has a standard deviation of
        # exactly zero; divide by 1 there so the row becomes zeros, not NaN.
        std[std == 0] = 1.0
        dist = (dist - mean) / std

    return dist