Commit fafc9e12 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add pixai tagger x, ci skip

parent 74bba37b
Loading
Loading
Loading
Loading
+0 −0

Empty file added.

+87 −0
Original line number Diff line number Diff line
import json
import os.path

import numpy as np
import pandas as pd
from ditk import logging
from hbutils.system import TemporaryDirectory
from hfutils.operate import get_hf_client

from imgutils.preprocess import parse_torchvision_transforms
from zoo.pixai_tagger.tags import load_tags
from .min_script import EndpointHandler


def sync(src_repo: str, dst_repo: str):
    """
    Stage the preprocessing config, tag table and thresholds of a tagger repository.

    The destination model repository is created (as private) when it does not exist yet.
    An ``EndpointHandler`` is built from ``src_repo`` and the following files are written
    into a temporary directory:

    * ``preprocessor.json`` - the handler's transform, parsed by ``parse_torchvision_transforms``.
    * ``selected_tags.csv`` - the tag table returned by ``load_tags`` plus an ``ips`` column.
    * ``categories.json`` - the two category descriptors (0 = general, 4 = character).
    * ``thresholds.csv`` - the handler's default threshold for each category.

    This function does not upload the staged files; it prints the directory tree and
    then waits for a line on stdin before leaving the temporary directory.

    :param src_repo: Repository id passed to ``EndpointHandler``.
    :param dst_repo: Repository id of the destination model repository.
    """
    hf_client = get_hf_client()
    if not hf_client.repo_exists(repo_id=dst_repo, repo_type='model'):
        hf_client.create_repo(repo_id=dst_repo, repo_type='model', private=True)

    handler = EndpointHandler(repo_id=src_repo)
    with TemporaryDirectory() as upload_dir:
        preprocessor = handler.transform
        preprocessor_file = os.path.join(upload_dir, 'preprocessor.json')
        logging.info(f'Dumping preprocessor:\n{preprocessor}\nto file {preprocessor_file!r}.')
        # ensure_ascii=False can emit non-ASCII characters, so the file encoding is pinned
        # to UTF-8 instead of depending on the platform locale.
        with open(preprocessor_file, 'w', encoding='utf-8') as f:
            json.dump({
                'stages': parse_torchvision_transforms(preprocessor),
            }, f, sort_keys=True, ensure_ascii=False, indent=4)

        logging.info('Scanning tags ...')
        # Indices below gen_tag_count are general tags (0), the remaining ones are character tags (4).
        tag_count = len(handler.index_to_tag_map)
        categories = np.where(np.arange(tag_count) < handler.gen_tag_count, 0, 4).astype(np.int32)
        df_src_tags = pd.DataFrame({
            'name': [name for _, name in sorted(handler.index_to_tag_map.items())],
            'category': categories,
        })
        df_tags = load_tags(df_src_tags)
        # Only character tags (category 4) carry IP names; every other row gets an empty list.
        df_tags['ips'] = [
            handler.character_ip_mapping.get(name, []) if category == 4 else []
            for name, category in zip(df_tags['name'], df_tags['category'])
        ]
        logging.info(f'Tags:\n{df_tags}')
        df_tags.to_csv(os.path.join(upload_dir, 'selected_tags.csv'), index=False)

        with open(os.path.join(upload_dir, 'categories.json'), 'w', encoding='utf-8') as f:
            json.dump([
                {
                    "category": 0,
                    "name": "general"
                },
                {
                    "category": 4,
                    "name": "character"
                },
            ], f, sort_keys=True, ensure_ascii=False, indent=4)
        df_th = pd.DataFrame([
            {'category': 0, 'name': 'general', 'threshold': handler.default_general_threshold},
            {'category': 4, 'name': 'character', 'threshold': handler.default_character_threshold},
        ])
        df_th.to_csv(os.path.join(upload_dir, 'thresholds.csv'), index=False)

        # NOTE(review): the model itself is not exported here. A bare `handler.model`
        # expression (which had no effect) used to sit at this point - presumably a
        # placeholder for the model export; confirm.

        # Show what was staged, then pause so the directory can be inspected before the
        # `with` block ends (the temporary directory is presumably removed on exit).
        os.system(f'tree {upload_dir!r}')
        input()
    # print(df_tags)

    # pprint(handler.index_to_tag_map)
    # pprint(handler.gen_tag_count)
    # pprint(handler.character_tag_count)
    # pprint(handler.character_ip_mapping)


if __name__ == '__main__':
    # Script entry point: enable INFO logging, then run the sync for the v0.9 tagger.
    logging.try_init_root(logging.INFO)
    sync(src_repo='pixai-labs/pixai-tagger-v0.9', dst_repo='deepghs/pixai-tagger-v0.9-onnx')
+228 −0
Original line number Diff line number Diff line
import base64
import io
import json
import logging
import time
from pathlib import Path
from typing import Any

import requests
import timm
import torch
import torchvision.transforms as transforms
from PIL import Image
from huggingface_hub import hf_hub_download

from imgutils.preprocess import parse_torchvision_transforms


class TaggingHead(torch.nn.Module):
    """Classification head: one linear layer followed by an element-wise sigmoid."""

    def __init__(self, input_dim, num_classes):
        """
        Create the head.

        ``input_dim`` is the size of the incoming feature vector and ``num_classes`` the
        number of outputs; both are also stored as attributes.
        """
        super().__init__()
        self.input_dim, self.num_classes = input_dim, num_classes
        # The linear layer stays wrapped in a Sequential so its parameters keep the
        # ``head.0.*`` names in the state dict.
        linear = torch.nn.Linear(input_dim, num_classes)
        self.head = torch.nn.Sequential(linear)

    def forward(self, x):
        """Return the sigmoid of the linear projection of ``x``."""
        return torch.sigmoid(self.head(x))


def get_tags(tags_file: Path) -> tuple[dict[str, int], int, int]:
    """
    Read the tag description JSON file.

    Returns a tuple ``(tag_map, gen_tag_count, character_tag_count)`` where ``tag_map`` is
    the ``"tag_map"`` entry of the file and the two counts come from its ``"tag_split"`` entry.
    """
    tag_info = json.loads(tags_file.read_text(encoding="utf-8"))
    tag_map = tag_info["tag_map"]
    split = tag_info["tag_split"]
    return tag_map, split["gen_tag_count"], split["character_tag_count"]


def get_character_ip_mapping(mapping_file: Path):
    """Parse the UTF-8 encoded JSON file ``mapping_file`` and return its content as-is."""
    return json.loads(mapping_file.read_text(encoding="utf-8"))


def get_encoder():
    """
    Build the encoder backbone.

    The timm model ``hf_hub:SmilingWolf/wd-eva02-large-tagger-v3`` is created without
    pretrained weights and its classifier is reset to zero classes.
    """
    backbone = timm.create_model(
        "hf_hub:SmilingWolf/wd-eva02-large-tagger-v3",
        pretrained=False,
    )
    backbone.reset_classifier(0)
    return backbone


def get_decoder():
    """Build the tagging head: 1024 input features mapped to 13461 classes."""
    return TaggingHead(input_dim=1024, num_classes=13461)


def get_model():
    """Return the full model: the encoder followed by the tagging head, as a Sequential."""
    return torch.nn.Sequential(get_encoder(), get_decoder())


def load_model(weights_file, device):
    """
    Build the model and load its weights.

    The state dict is read from ``weights_file`` with ``weights_only=True`` and mapped to
    ``device``; the model is then moved to ``device`` and switched to eval mode.
    """
    network = get_model()
    network.load_state_dict(
        torch.load(weights_file, map_location=device, weights_only=True)
    )
    # Both Module.to() and Module.eval() return the module itself.
    return network.to(device).eval()


def pure_pil_alpha_to_color_v2(
        image: Image.Image, color: tuple[int, int, int] = (255, 255, 255)
) -> Image.Image:
    """
    Flatten an image with an alpha channel onto a solid background.

    A new RGB image of the same size is filled with ``color`` and ``image`` is pasted on
    top of it, using the 4th band of ``image`` as the paste mask. The input must
    therefore have at least four bands (e.g. RGBA).
    """
    image.load()  # split() needs the pixel data to be loaded
    canvas = Image.new("RGB", image.size, color)
    alpha_band = image.split()[3]  # index 3 is the alpha band
    canvas.paste(image, mask=alpha_band)
    return canvas


def pil_to_rgb(image: Image.Image) -> Image.Image:
    """
    Convert an image to RGB mode.

    RGBA images are flattened with ``pure_pil_alpha_to_color_v2``, palette ("P") images
    are converted to RGBA first and then flattened the same way, and every other mode
    goes through a plain ``convert("RGB")``.
    """
    if image.mode == "RGBA":
        return pure_pil_alpha_to_color_v2(image)
    if image.mode == "P":
        return pure_pil_alpha_to_color_v2(image.convert("RGBA"))
    return image.convert("RGB")


class EndpointHandler:
    """
    Inference handler that tags a single image.

    On construction the weights, the tag table and the character-to-IP mapping are
    downloaded from a Hugging Face model repository, and the model and the
    preprocessing pipeline are built. Calling the instance runs one image through
    the model and returns the tags above the configured thresholds.
    """

    def __init__(self, repo_id: str = 'pixai-labs/pixai-tagger-v0.9'):
        """
        Download the model files from ``repo_id`` and prepare the handler.

        :param repo_id: Model repository that hosts ``model_v0.9.pth``,
            ``tags_v0.9_13k.json`` and ``char_ip_map.json``.
        """
        weights_file = Path(hf_hub_download(
            repo_id=repo_id,
            repo_type='model',
            filename="model_v0.9.pth",
        ))
        tags_file = Path(hf_hub_download(
            repo_id=repo_id,
            repo_type='model',
            filename="tags_v0.9_13k.json",
        ))
        mapping_file = Path(hf_hub_download(
            repo_id=repo_id,
            repo_type='model',
            filename="char_ip_map.json",
        ))

        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = load_model(str(weights_file), self.device)
        self.transform = transforms.Compose(
            [
                transforms.Resize((448, 448)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        self.fetch_image_timeout = 5.0  # passed as `timeout` to requests.get
        self.default_general_threshold = 0.3
        self.default_character_threshold = 0.85

        tag_map, self.gen_tag_count, self.character_tag_count = get_tags(tags_file)

        # Invert the tag_map for efficient index-to-tag lookups
        self.index_to_tag_map = {v: k for k, v in tag_map.items()}

        self.character_ip_mapping = get_character_ip_mapping(mapping_file)

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """
        Tag one image.

        ``data["inputs"]`` (or ``data`` itself when that key is missing) is either a PIL
        image, or a dict with a ``"url"`` entry (fetched over HTTP) or an ``"image"``
        entry (base64-encoded image bytes). ``data["parameters"]`` may override
        ``"general_threshold"`` and ``"character_threshold"``.

        Note that the consumed keys are popped from the dicts that are passed in.

        :param data: Request payload as described above.
        :return: Dict with ``"feature"`` (general tags), ``"character"`` (character tags)
            and ``"ip"`` (sorted, de-duplicated IP names mapped from the character tags).
        :raises ValueError: If neither an image nor a URL is provided.
        """
        inputs = data.pop("inputs", data)

        # perf_counter() is monotonic, so the measured durations cannot be skewed by
        # wall-clock adjustments.
        fetch_start_time = time.perf_counter()
        if isinstance(inputs, Image.Image):
            image = inputs
        elif image_url := inputs.pop("url", None):
            with requests.get(
                    image_url, stream=True, timeout=self.fetch_image_timeout
            ) as res:
                res.raise_for_status()
                image = Image.open(res.raw)
        elif image_base64_encoded := inputs.pop("image", None):
            image = Image.open(io.BytesIO(base64.b64decode(image_base64_encoded)))
        else:
            raise ValueError(f"No image or url provided: {data}")
        # remove alpha channel if it exists
        image = pil_to_rgb(image)
        fetch_time = time.perf_counter() - fetch_start_time

        parameters = data.pop("parameters", {})
        general_threshold = parameters.pop(
            "general_threshold", self.default_general_threshold
        )
        character_threshold = parameters.pop(
            "character_threshold", self.default_character_threshold
        )

        inference_start_time = time.perf_counter()
        with torch.inference_mode():
            # Preprocess the image on the CPU and add the batch dimension.
            image_tensor = self.transform(image).unsqueeze(0)

            # Pinning host memory only helps the asynchronous host-to-GPU copy and needs
            # an accelerator; on a CPU-only host it raises, so it is limited to CUDA.
            if torch.device(self.device).type == "cuda":
                image_tensor = image_tensor.pin_memory()

            # Move the image to the target device (asynchronously when pinned).
            image_tensor = image_tensor.to(self.device, non_blocking=True)

            # Run the model; take the probabilities of the single image in the batch.
            probs = self.model(image_tensor)[0]

            # Threshold on the device: the first gen_tag_count entries are general tags,
            # the remaining ones are character tags.
            general_mask = probs[: self.gen_tag_count] > general_threshold
            character_mask = probs[self.gen_tag_count:] > character_threshold

            # Indices of the positive tags; character indices are shifted back into the
            # global index space.
            general_indices = general_mask.nonzero(as_tuple=True)[0]
            character_indices = (
                    character_mask.nonzero(as_tuple=True)[0] + self.gen_tag_count
            )

            # Combine indices and move the small result tensor to the CPU
            combined_indices = torch.cat((general_indices, character_indices)).cpu()

        inference_time = time.perf_counter() - inference_start_time

        post_process_start_time = time.perf_counter()

        cur_gen_tags = []
        cur_char_tags = []

        # Use the pre-computed index-to-tag map; tolist() yields plain Python ints.
        for idx in combined_indices.tolist():
            tag = self.index_to_tag_map[idx]
            if idx < self.gen_tag_count:
                cur_gen_tags.append(tag)
            else:
                cur_char_tags.append(tag)

        # Collect the IP names of all detected characters, de-duplicated and sorted.
        ip_tags = []
        for tag in cur_char_tags:
            if tag in self.character_ip_mapping:
                ip_tags.extend(self.character_ip_mapping[tag])
        ip_tags = sorted(set(ip_tags))
        post_process_time = time.perf_counter() - post_process_start_time

        logging.info(
            f"Timing - Fetch: {fetch_time:.3f}s, Inference: {inference_time:.3f}s, Post-process: {post_process_time:.3f}s, Total: {fetch_time + inference_time + post_process_time:.3f}s"
        )

        return {
            "feature": cur_gen_tags,
            "character": cur_char_tags,
            "ip": ip_tags,
        }


if __name__ == '__main__':
    # Smoke check: build the handler, then print its transform and the parsed form of it.
    transform = EndpointHandler().transform
    print(transform)
    print(parse_torchvision_transforms(transform))
+82 −0
Original line number Diff line number Diff line
from functools import lru_cache

import pandas as pd
from ditk import logging
from hfutils.operate import get_hf_client
from tqdm import tqdm
from waifuc.utils import srequest

from zoo.wd14.tags import _get_tag_by_name, _db_session

# Maps tag category names to their numeric category codes.
# NOTE(review): not referenced in the visible part of this module; the codes presumably
# follow danbooru's tag category ids - confirm.
_CATEGORY_MAPS = {
    'general': 0,
    'character': 4,
}


@lru_cache()
def _get_rating_count_by_name(tag_name: str):
    """
    Query danbooru's post counter for ``rating:<tag_name>`` and return the post count.

    Despite the parameter name, the value is used as the argument of a ``rating:`` search.
    Results are memoized per ``tag_name`` by ``lru_cache``.
    """
    session = _db_session()
    logging.info(f'Getting count for {tag_name!r} ...')
    response = srequest(
        session, 'GET', 'https://danbooru.donmai.us/counts/posts.json',
        params={'tags': f'rating:{tag_name}'},
    )
    payload = response.json()
    logging.info(f'Result of {tag_name!r}: {payload!r}')
    return payload['counts']['posts']


def load_tags(df_src_tags: pd.DataFrame):
    """
    Attach danbooru tag ids and post counts to a table of source tags.

    ``df_src_tags`` must provide ``name`` and ``category`` columns. Every row is first
    looked up by ``(category, name)`` in the ``danbooru.donmai.us/tags.csv`` pool from
    the ``deepghs/site_tags`` dataset. Rows that miss the pool and have a category
    below 9 are resolved with ``_get_tag_by_name``; when that returns a tag under a
    different name the row gets ``tag_id`` and ``count`` of -1, and when it reports a
    different category, that category replaces the source one. Rows with any other
    category also get -1 for both fields.

    :param df_src_tags: Source tags, one row per tag, in output order.
    :return: DataFrame with columns ``id`` (row position), ``tag_id``, ``name``,
        ``category`` and ``count``.
    """
    hf_client = get_hf_client()
    pool_csv = hf_client.hf_hub_download(
        repo_id='deepghs/site_tags',
        repo_type='dataset',
        filename='danbooru.donmai.us/tags.csv'
    )
    df_p_tags = pd.read_csv(pool_csv)
    logging.info(f'Loaded danbooru tags pool, columns: {df_p_tags.columns!r}')
    # Pool records indexed by (category, name) for direct lookups.
    d_p_tags = {
        (record['category'], record['name']): record
        for record in df_p_tags.to_dict('records')
    }

    rows = []
    for i, src_tag in enumerate(tqdm(df_src_tags.to_dict('records'), desc='Scan Tags')):
        tag_name = src_tag['name']
        category = src_tag['category']
        pooled = d_p_tags.get((category, tag_name))
        if pooled is not None:
            tag_id, count = pooled['id'], pooled['post_count']
        elif category < 9:
            logging.warning(f'Cannot find tag {tag_name!r}, category: {category!r}.')
            tag_info = _get_tag_by_name(tag_name)
            if tag_info['name'] == tag_name:
                logging.info(f'Tag info found from danbooru - {tag_info!r}.')
                tag_id, count = tag_info['id'], tag_info['post_count']
                if tag_info['category'] != category:
                    logging.warning(f'Category not match for tag {tag_name!r}, '
                                    f'replace category {category!r} --> {tag_info["category"]!r}')
                    category = tag_info['category']
            else:
                logging.warning(f'Not found matching tags for {tag_name!r}, will be ignored.')
                tag_id = count = -1
        else:
            logging.warning(f'Unknown tag {tag_name!r} ...')
            tag_id = count = -1

        rows.append({
            'id': i,
            'tag_id': tag_id,
            'name': tag_name,
            'category': category,
            'count': count,
        })

    return pd.DataFrame(rows)


if __name__ == '__main__':
    logging.try_init_root(level=logging.INFO)
    # load_tags() requires a frame with 'name' and 'category' columns; calling it without
    # an argument raised a TypeError. A tiny sample frame keeps this smoke run working.
    df_demo_tags = pd.DataFrame([
        {'name': '1girl', 'category': 0},
        {'name': 'hatsune_miku', 'category': 4},
    ])
    df = load_tags(df_demo_tags)
    df.to_csv('test_df.csv', index=False)