参照OIWiki对于ST表的C++实现,实现了一套Python版本的ST表,支持Python类型检查的泛型(Generic)推导功能。ST表是一种用于解决区间查询问题的数据结构,其主要用途是在O(1)
时间内查询区间内的最值等可重复贡献的问题。需要使用的时候,直接复制下面的代码到Python文件中即可。
import math
from typing import Callable, Generic, TypeVar
T = TypeVar('T')
class SparseTable(Generic[T]):
"""
ST表的Python泛型实现
原理参见: https://oi-wiki.org/ds/sparse-table/
@note: T为元素类型, 实际使用时无需指定, 会自动推导, 仅用于IDE类型提示
用法::
>>> st = SparseTable([0, 1, 2, 3, 4, 5], min)
>>> assert st.query(0, 4) == 0
>>> assert st.query(1, 5) == 1
>>> assert st.query(2, 5) == 2
>>> # 也可以直接使用下标访问
>>> assert st[0, 4] == 0
>>> assert st[1, 5] == 1
>>> assert st[2, 5] == 2
>>> st = SparseTable([0, 1, 2, 3, 4, 5], max)
>>> assert st.query(0, 4) == 4
>>> assert st.query(1, 5) == 5
>>> assert st.query(2, 5) == 5
>>> assert st[0, 4] == 4
>>> assert st[1, 5] == 5
>>> assert st[2, 5] == 5
>>> st = SparseTable(['a', 'b', 'c', 'd', 'e', 'f'], max)
>>> assert st.query(0, 4) == 'e'
>>> assert st.query(1, 5) == 'f'
>>> assert st.query(2, 5) == 'f'
>>> assert st[0, 4] == 'e'
>>> assert st[1, 5] == 'f'
>>> assert st[2, 5] == 'f'
>>> st = SparseTable([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], max)
>>> assert st.query(0, 4) == 0.5
>>> assert st.query(1, 5) == 0.6
>>> assert st.query(2, 5) == 0.6
>>> assert st[0, 4] == 0.5
>>> assert st[1, 5] == 0.6
>>> assert st[2, 5] == 0.6
"""
def __init__(self, list_: list[T], op_func: Callable[[T, T], T]):
"""
@param list_: 原始列表
@param op_func: 运算函数, 接受两个参数, 返回一个参数
"""
self._list = list_
self._op_func = op_func
elem_type = type(self._list[0]) if self._list else type(None)
l = len(self._list)
log2_l = math.floor(math.log2(l))
self._st = [[elem_type()] * (log2_l + 1) for _ in range(l)] # type: list[list[T]]
for i in range(l):
self._st[i][0] = self._list[i]
for j in range(1, log2_l + 1): # 1..log2_l
range_size = 1 << j
for i in range(l - range_size + 1):
# merge
self._st[i][j] = self._op_func(
self._st[i][j - 1],
self._st[i + (1 << (j - 1))][j - 1]
)
def query(self, l: int, r: int) -> T:
""" 查询区间 [l, r] 的结果 """
range_size = r - l + 1
sub_range = math.floor(math.log2(range_size))
return self._op_func(self._st[l][sub_range], self._st[r - (1 << sub_range) + 1][sub_range])
def __len__(self):
return len(self._list)
def __getitem__(self, item):
if not isinstance(item, tuple) or len(item) != 2:
raise TypeError('SparseTable.__getitem__ only accepts 2-tuple')
l, r = item
return self.query(l, r)