Commit 324c8122 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add unittest for url functions

parent cfb10ecc
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -8,3 +8,4 @@ from .decode import *
from .encode import *
from .image import *
from .layer import *
from .url import *
 No newline at end of file

imgutils/data/url.py

0 → 100644
+88 −0
Original line number Diff line number Diff line
import io
from typing import Optional

import pyrfc6266
from PIL import Image
from hbutils.system import urlsplit
from huggingface_hub import get_session
from tqdm import tqdm
from urlobject import URLObject

__all__ = [
    'download_image_from_url',
    'is_http_url',
]


def download_image_from_url(url: str, silent: bool = False, expected_size: Optional[int] = None,
                            **kwargs) -> Image.Image:
    if _is_github_url(url):
        url = _process_github_url_for_downloading(url)
    elif _is_hf_url(url):
        url = _process_hf_url_for_downloading(url)

    session = get_session()
    with session.get(url, stream=True, allow_redirects=True, **kwargs) as response:
        expected_size = expected_size or response.headers.get('Content-Length', None)
        expected_size = int(expected_size) if expected_size is not None else expected_size
        filename = None
        if response.headers.get('Content-Disposition'):
            filename = pyrfc6266.parse_filename(response.headers.get('Content-Disposition'))
        filename = filename or urlsplit(url).filename

        with io.BytesIO() as bf:
            with tqdm(total=expected_size, unit='B', unit_scale=True, unit_divisor=1024,
                      desc=filename, disable=silent) as pbar:
                for chunk in response.iter_content(chunk_size=1024):
                    bf.write(chunk)
                    pbar.update(len(chunk))

            bf.seek(0)
            image = Image.open(bf)
            image.load()
            return image


def is_http_url(url: str) -> bool:
    if not isinstance(url, str):
        return False

    split = urlsplit(url)
    return split.scheme == 'http' or split.scheme == 'https'


_GITHUB_SUFFIX = {('github', 'com')}


def _is_github_url(url: str) -> bool:
    # assume that is_http_url(url) is True
    return tuple(urlsplit(url).host.split('.')[-2:]) in _GITHUB_SUFFIX


def _process_github_url_for_downloading(url: str) -> str:
    return str(URLObject(url).with_query('raw=True'))


_HF_SUFFIX = {('hf', 'co'), ('huggingface', 'co')}


def _is_hf_url(url: str) -> bool:
    # assume that is_http_url(url) is True
    return tuple(urlsplit(url).host.split('.')[-2:]) in _HF_SUFFIX


def _process_hf_url_for_downloading(url: str) -> str:
    split = urlsplit(url)
    segments = split.path_segments
    if len(segments) >= 2 and (segments[1] == 'datasets' or segments[1] == 'spaces'):
        position = 4
    else:
        position = 3

    if len(segments) > position and segments[position] == 'blob':
        segments = [*segments[:position], 'resolve', *segments[position + 1:]]
    elif len(segments) > position and segments[position] == 'resolve':
        pass
    else:
        raise ValueError(f'Unsupported huggingface URL - {url!r}.')
    return f'{split.scheme}://{split.host}{"/".join(segments)}'
+3 −1
Original line number Diff line number Diff line
@@ -18,3 +18,5 @@ bchlib>=1.0.0,!=2.0.0,!=2.0.1,!=2.1.0,!=2.1.1,!=2.1.2
piexif
tokenizers>=0.20.0; python_version >= '3.9'
tokenizers>=0.20.0,<0.21; python_version < '3.9'
pyrfc6266>=1
urlobject>=2
 No newline at end of file

test/data/test_url.py

0 → 100644
+76 −0
Original line number Diff line number Diff line
import pytest

from imgutils.data import is_http_url, download_image_from_url, load_image
from imgutils.data.url import _is_github_url, _process_github_url_for_downloading, _is_hf_url, \
    _process_hf_url_for_downloading
from test.testings import get_testfile


@pytest.mark.unittest
class TestDataURL:
    @pytest.mark.parametrize(['url', 'local_image'], [
        (
                'https://github.com/deepghs/imgutils/blob/main/test/testfile/nian_640.png',
                ('nian_640.png',)
        ),
        (
                'https://huggingface.co/deepghs/eattach_monochrome_experiments/blob/main/mlp_layer1_seed1/plot_confusion.png',
                ('plot_confusion.png',)
        )
    ])
    def test_download_image_from_url(self, url, local_image, image_diff):
        local_image_file = get_testfile(*local_image)
        actual_image = download_image_from_url(url)
        expected_image = load_image(local_image_file, mode='RGB', force_background='white')
        assert image_diff(
            load_image(actual_image, mode='RGB', force_background='white'),
            expected_image,
            throw_exception=False,
        ) < 1e-2

    def test_is_http_url(self):
        assert is_http_url('http://example.com')
        assert is_http_url('https://example.com')
        assert not is_http_url('ftp://example.com')
        assert not is_http_url('not_a_url')
        assert not is_http_url(123)

    def test_is_github_url(self):
        assert _is_github_url('https://github.com/user/repo')
        assert not _is_github_url('https://gitlab.com/user/repo')

    def test_process_github_url_for_downloading(self):
        url = 'https://github.com/user/repo'
        result = _process_github_url_for_downloading(url)
        assert result == 'https://github.com/user/repo?raw=True'

    def test_is_hf_url(self):
        assert _is_hf_url('https://huggingface.co/user/repo')
        assert _is_hf_url('https://hf.co/user/repo')
        assert not _is_hf_url('https://example.com/user/repo')

    @pytest.mark.parametrize("url, expected", [
        ('https://huggingface.co/datasets/user/repo/blob/main/file.txt',
         'https://huggingface.co/datasets/user/repo/resolve/main/file.txt'),
        ('https://huggingface.co/user/repo/blob/main/file.txt',
         'https://huggingface.co/user/repo/resolve/main/file.txt'),
        ('https://huggingface.co/spaces/user/repo/blob/main/file.txt',
         'https://huggingface.co/spaces/user/repo/resolve/main/file.txt'),
        ('https://huggingface.co/user/repo/resolve/main/file.txt',
         'https://huggingface.co/user/repo/resolve/main/file.txt'),

        ('https://hf.co/datasets/user/repo/blob/main/file.txt',
         'https://hf.co/datasets/user/repo/resolve/main/file.txt'),
        ('https://hf.co/user/repo/blob/main/file.txt',
         'https://hf.co/user/repo/resolve/main/file.txt'),
        ('https://hf.co/spaces/user/repo/blob/main/file.txt',
         'https://hf.co/spaces/user/repo/resolve/main/file.txt'),
        ('https://hf.co/user/repo/resolve/main/file.txt',
         'https://hf.co/user/repo/resolve/main/file.txt'),
    ])
    def test_process_hf_url_for_downloading(self, url, expected):
        assert _process_hf_url_for_downloading(url) == expected

    def test_process_hf_url_for_downloading_invalid(self):
        with pytest.raises(ValueError, match="Unsupported huggingface URL"):
            _process_hf_url_for_downloading('https://huggingface.co/user/repo/invalid/path')
+21.2 KiB
Loading image diff...