ST表的Python实现模板(支持泛型类型检查)

由Jeza Chen 发表于 August 7, 2024

参照OIWiki对于ST表的C++实现,实现了一套Python版本的ST表,支持Python类型检查的泛型(Generic)推导功能。ST表是一种用于解决区间查询问题的数据结构,其主要用途是在O(1)时间内查询区间内的最值等可重复贡献的问题。需要使用的时候,直接复制下面的代码到Python文件中即可。

import math

from typing import Callable, Generic, TypeVar

T = TypeVar('T')


class SparseTable(Generic[T]):
    """
    ST表的Python泛型实现
    原理参见: https://oi-wiki.org/ds/sparse-table/
    @note: T为元素类型, 实际使用时无需指定, 会自动推导, 仅用于IDE类型提示

    用法::
    >>> st = SparseTable([0, 1, 2, 3, 4, 5], min)
    >>> assert st.query(0, 4) == 0
    >>> assert st.query(1, 5) == 1
    >>> assert st.query(2, 5) == 2
    >>> # 也可以直接使用下标访问
    >>> assert st[0, 4] == 0
    >>> assert st[1, 5] == 1
    >>> assert st[2, 5] == 2

    >>> st = SparseTable([0, 1, 2, 3, 4, 5], max)
    >>> assert st.query(0, 4) == 4
    >>> assert st.query(1, 5) == 5
    >>> assert st.query(2, 5) == 5
    >>> assert st[0, 4] == 4
    >>> assert st[1, 5] == 5
    >>> assert st[2, 5] == 5

    >>> st = SparseTable(['a', 'b', 'c', 'd', 'e', 'f'], max)
    >>> assert st.query(0, 4) == 'e'
    >>> assert st.query(1, 5) == 'f'
    >>> assert st.query(2, 5) == 'f'
    >>> assert st[0, 4] == 'e'
    >>> assert st[1, 5] == 'f'
    >>> assert st[2, 5] == 'f'

    >>> st = SparseTable([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], max)
    >>> assert st.query(0, 4) == 0.5
    >>> assert st.query(1, 5) == 0.6
    >>> assert st.query(2, 5) == 0.6
    >>> assert st[0, 4] == 0.5
    >>> assert st[1, 5] == 0.6
    >>> assert st[2, 5] == 0.6
    """
    def __init__(self, list_: list[T], op_func: Callable[[T, T], T]):
        """
        @param list_: 原始列表
        @param op_func: 运算函数, 接受两个参数, 返回一个参数
        """
        self._list = list_
        self._op_func = op_func
        elem_type = type(self._list[0]) if self._list else type(None)
        l = len(self._list)
        log2_l = math.floor(math.log2(l))
        self._st = [[elem_type()] * (log2_l + 1) for _ in range(l)]  # type: list[list[T]]
        for i in range(l):
            self._st[i][0] = self._list[i]
        for j in range(1, log2_l + 1):  # 1..log2_l
            range_size = 1 << j
            for i in range(l - range_size + 1):
                # merge
                self._st[i][j] = self._op_func(
                    self._st[i][j - 1],
                    self._st[i + (1 << (j - 1))][j - 1]
                )

    def query(self, l: int, r: int) -> T:
        """ 查询区间 [l, r] 的结果 """
        range_size = r - l + 1
        sub_range = math.floor(math.log2(range_size))
        return self._op_func(self._st[l][sub_range], self._st[r - (1 << sub_range) + 1][sub_range])

    def __len__(self):
        return len(self._list)

    def __getitem__(self, item):
        if not isinstance(item, tuple) or len(item) != 2:
            raise TypeError('SparseTable.__getitem__ only accepts 2-tuple')
        l, r = item
        return self.query(l, r)