最近最少使用(LRU)缓存函数装饰器-源码

babyfengfjx / 2024-04-18 / 原文

#####################################################################

 最近最少使用(LRU)缓存函数装饰器

######################################################################

_CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])

class _HashedSeq(list):
    """这个类保证了每个元素的hash()最多只被调用一次。
    这很重要,因为lru_cache()在缓存未命中时会多次哈希键。

    """

    __slots__ = 'hashvalue'

    def __init__(self, tup, hash=hash):
        self[:] = tup
        self.hashvalue = hash(tup)

    def __hash__(self):
        return self.hashvalue

def _make_key(args, kwds, typed,
             kwd_mark = (object(),),
             fasttypes = {int, str},
             tuple=tuple, type=type, len=len):
    """从可选的类型化的位置参数和关键字参数中创建一个缓存键

    键的构造方式尽可能扁平而不是嵌套结构,这样可以节省内存。

    如果只有一个参数,并且已知其数据类型可以缓存其哈希值,
    那么该参数将被返回而不需要包装器。这节省了空间并提高了查找速度。

    """
    # 下面的代码都依赖于kwds保留用户输入的顺序。
    # 以前,我们在循环之前对kwds进行了排序()。新的方法是*快得多*;
    # 然而,这意味着f(x=1, y=2)现在将被视为与f(y=2, x=1)不同的调用,
    # 这将分别被缓存。

    key = args
    if kwds:
        key += kwd_mark
        for item in kwds.items():
            key += item
    if typed:
        key += tuple(type(v) for v in args)
        if kwds:
            key += tuple(type(v) for v in kwds.values())
    elif len(key) == 1 and type(key[0]) in fasttypes:
        return key[0]
    return _HashedSeq(key)

def lru_cache(maxsize=128, typed=False):
    """最近最少使用(LRU)缓存装饰器。

    如果将*maxsize*设置为None,将禁用LRU特性,并且缓存可以无限制地增长。

    如果*typed*为True,将分别缓存不同类型参数的调用结果。
    例如,f(3.0)和f(3)将被视为两个不同的调用,分别有各自的结果。

    缓存函数的参数必须是可哈希的。

    使用f.cache_info()查看缓存统计信息命名元组(hits, misses, maxsize, currsize)。
    使用f.cache_clear()清除缓存和统计信息。
    使用f.__wrapped__访问底层函数。

    参见:https://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)

    """

    # 用户应该只通过其公共API访问lru_cache:
    #       cache_info, cache_clear, 和 f.__wrapped__
    # lru_cache的内部实现被封装以保证线程安全,
    # 并允许实现发生变化(包括可能的C语言版本)。

    if isinstance(maxsize, int):
        # 负的maxsize被视为0
        if maxsize < 0:
            maxsize = 0
    elif callable(maxsize) and isinstance(typed, bool):
        # 用户函数通过maxsize参数直接传递
        user_function, maxsize = maxsize, 128
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        wrapper.cache_parameters = lambda : {'maxsize': maxsize, 'typed': typed}
        return update_wrapper(wrapper, user_function)
    elif maxsize is not None:
        raise TypeError(
            '预期第一个参数为整数、可调用对象或None')

    def decorating_function(user_function):
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        wrapper.cache_parameters = lambda : {'maxsize': maxsize, 'typed': typed}
        return update_wrapper(wrapper, user_function)

    return decorating_function

def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo):
    # 所有lru缓存实例共享的常量:
    sentinel = object()          # 用于信号缓存未命中的唯一对象
    make_key = _make_key         # 从函数参数构建一个键
    PREV, NEXT, KEY, RESULT = 0, 1, 2, 3   # 链接字段的名称

    cache = {}
    hits = misses = 0
    full = False
    cache_get = cache.get    # 绑定方法来查找一个键或返回None
    cache_len = cache.__len__  # 不调用len()函数来获取缓存大小
    lock = RLock()           # 因为链表更新不是线程安全的
    root = []                # 循环双向链表的根
    root[:] = [root, root, None, None]     # 初始化为指向自身

    if maxsize == 0:

        def wrapper(*args, **kwds):
            # 不进行缓存 -- 只是统计更新
            nonlocal misses
            misses += 1
            result = user_function(*args, **kwds)
            return result

    elif maxsize is None:

        def wrapper(*args, **kwds):
            # 无顺序或大小限制的简单缓存
            nonlocal hits, misses
            key = make_key(args, kwds, typed)
            result = cache_get(key, sentinel)
            if result is not sentinel:
                hits += 1
                return result
            misses += 1
            result = user_function(*args, **kwds)
            cache[key] = result
            return result

    else:

        def wrapper(*args, **kwds):
            # 按最近访问跟踪的尺寸限制缓存
            nonlocal root, hits, misses, full
            key = make_key(args, kwds, typed)
            with lock:
                link = cache_get(key)
                if link is not None:
                    # 将链接移动到循环队列的前端
                    link_prev, link_next, _key, result = link
                    link_prev[NEXT] = link_next
                    link_next[PREV] = link_prev
                    last = root[PREV]
                    last[NEXT] = root[PREV] = link
                    link[PREV] = last
                    link[NEXT] = root
                    hits += 1
                    return result
                misses += 1
            result = user_function(*args, **kwds)
            with lock:
                if key in cache:
                    # 来到这里意味着在释放锁的同时,这个相同的键被添加到了缓存中。
                    # 由于链接更新已经完成,我们只需要返回计算结果并更新未命中数。
                    pass
                elif full:
                    # 使用旧的根来存储新的键和结果。
                    oldroot = root
                    oldroot[KEY] = key
                    oldroot[RESULT] = result
                    # 清空最老的链接并使其成为新的根。
                    # 保留对旧键和旧结果的引用,
                    # 以防止在更新期间它们的引用计数降至零。
                    # 这将防止在我们还调整链接时运行潜在的任意对象清理代码(例如 __del__)。
                    root = oldroot[NEXT]
                    oldkey = root[KEY]
                    oldresult = root[RESULT]
                    root[KEY] = root[RESULT] = None
                    # 现在更新缓存字典。
                    del cache[oldkey]
                    # 将可能重新进入的cache[key]赋值
                    # 放在最后,因为根和链接已经被置于一致的状态。
                    cache[key] = oldroot
                else:
                    # 将结果放入队列前端的新链接中。
                    last = root[PREV]
                    link = [last, root, key, result]
                    last[NEXT] = root[PREV] = cache[key] = link
                    # 使用cache_len绑定方法而不是len()函数
                    # 该函数本身可能被包装在lru_cache中。
                    full = (cache_len() >= maxsize)
            return result

    def cache_info():
        """报告缓存统计信息"""
        with lock:
            return _CacheInfo(hits, misses, maxsize, cache_len())

    def cache_clear():
        """清除缓存和缓存统计信息"""
        nonlocal hits, misses, full
        with lock:
            cache.clear()
            root[:] = [root, root, None, None]
            hits = misses = 0
            full = False

    wrapper.cache_info = cache_info
    wrapper.cache_clear = cache_clear
    return wrapper

try:
    from _functools import _lru_cache_wrapper
except ImportError:
    pass