Commit d6d5463a authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): optimize ts_lru_cache

parent effc3784
Loading
Loading
Loading
Loading
+27 −4
Original line number Diff line number Diff line
@@ -13,14 +13,31 @@ Usage:
    ...     # Some expensive computation
    ...     return x + y
"""

import os
import threading
from collections import defaultdict
from functools import lru_cache, wraps

from typing import Literal

__all__ = ['ts_lru_cache']

LevelTyping = Literal['global', 'process', 'thread']


def ts_lru_cache(**options):
def _get_context_key(level: LevelTyping = 'global'):
    if level == 'global':
        return None
    elif level == 'process':
        return os.getpid()
    elif level == 'thread':
        return os.getpid(), threading.get_ident()
    else:
        raise ValueError(f'Invalid cache level, '
                         f'\'global\', \'process\' or \'thread\' expected but {level!r} found.')


def ts_lru_cache(level: LevelTyping = 'global', **options):
    """
    A thread-safe version of the lru_cache decorator.

@@ -45,6 +62,7 @@ def ts_lru_cache(**options):
        due to lock acquisition. Use it when thread-safety is more critical than
        maximum performance in multithreaded scenarios.
    """
    _ = _get_context_key(level)

    def _decorator(func):
        """
@@ -56,9 +74,10 @@ def ts_lru_cache(**options):
        :return: The wrapped function with thread-safe caching.
        :rtype: function
        """

        @lru_cache(**options)
        @wraps(func)
        def _cached_func(*args, **kwargs):
        def _cached_func(*args, __context_key=None, **kwargs):
            """
            Cached version of the original function.

@@ -69,6 +88,7 @@ def ts_lru_cache(**options):
            """
            return func(*args, **kwargs)

        lock_pool = defaultdict(threading.Lock)
        lock = threading.Lock()

        @wraps(_cached_func)
@@ -84,8 +104,11 @@ def ts_lru_cache(**options):

            :return: The result of the cached function call.
            """
            context_key = _get_context_key(level=level)
            with lock:
                return _cached_func(*args, **kwargs)
                _context_lock = lock_pool[context_key]
            with _context_lock:
                return _cached_func(*args, __context_key=context_key, **kwargs)

        # Preserve cache_info and cache_clear methods if they exist
        if hasattr(_cached_func, 'cache_info'):