Commit 5f229039 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add global seed

parent 4236b412
Loading
Loading
Loading
Loading
+23 −4
Original line number Diff line number Diff line
from functools import lru_cache
from typing import Optional
from typing import Optional, Tuple

import numpy as np
from PIL import Image, ImageFilter
from PIL.Image import Resampling
from huggingface_hub import hf_hub_download
from scipy import signal

from ..data import ImageTyping, load_image
from ..data import ImageTyping, load_image, rgb_encode
from ..utils import open_onnx_model

__all__ = [
@@ -14,7 +15,8 @@ __all__ = [
    'is_monochrome',
]

_DEFAULT_MONOCHROME_CKPT = 'monochrome-resnet18-safe2-450.onnx'
# _DEFAULT_MONOCHROME_CKPT = 'monochrome-resnet18-safe2-450.onnx'
_DEFAULT_MONOCHROME_CKPT = 'monochrome-levit_d0.2-500.onnx'


@lru_cache()
@@ -65,9 +67,26 @@ def _hsv_encode(image: Image.Image, feature_bins: int = 180, mf: Optional[int] =
    return dist


def _2d_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
               normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image = image.resize(size, Resampling.BILINEAR)
    data = rgb_encode(image, order_='CHW')

    if normalize is not None:
        mean_, std_ = normalize
        mean = np.asarray([mean_]).reshape((-1, 1, 1))
        std = np.asarray([std_]).reshape((-1, 1, 1))
        data = (data - mean) / std

    return data


def get_monochrome_score(image: ImageTyping, ckpt: str = _DEFAULT_MONOCHROME_CKPT) -> float:
    image = load_image(image, mode='RGB')
    input_data = _hsv_encode(image).astype(np.float32)
    # input_data = _hsv_encode(image).astype(np.float32)
    input_data = _2d_encode(image).astype(np.float32)
    input_data = np.stack([input_data])
    output_data, = _monochrome_validate_model(ckpt).run(['output'], {'input': input_data})
    return float(output_data[0][1])
+5 −1
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@ from typing import Optional, Type
import torch
from accelerate import Accelerator
from ditk import logging
from hbutils.random import global_seed
from torch import nn
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
@@ -95,7 +96,10 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
          train_ratio: float = 0.8, batch_size: int = 4, feature_bins: int = 180, fc: Optional[int] = 75,
          max_epochs: int = 500, learning_rate: float = 0.001, weight_decay: float = 1e-3, preference: float = 0.0,
          num_workers: Optional[int] = 8, save_per_epoch: int = 10, eval_epoch: int = 5,
          model_name: str = 'alexnet'):
          model_name: str = 'alexnet', seed: Optional[int] = 0):
    if seed is not None:
        global_seed(seed)

    accelerator = Accelerator(
        # mixed_precision=self.cfgs.mixed_precision,
        step_scheduler_with_optimizer=False,