-
发表于 2024.07.28
-
线段树的应用题,由于之前没系统实现过线段树,所以这次从头学了一遍线段树的简单实现和应用。该题目就是在一个二维的坐标轴上,不断添加方块,并将每次添加后所有方块的最大高度添加到结果中。本质上,每次添加边长为
sideLength
的方块时,首先需要找到该方块所在x轴区间[left, left + sideLength - 1]
的最大高度h
,然后将该区间的最大高度更新为h + sideLength
。不难得知此类区间更新+查询类型题目可使用线段树解决,每个节点维护一个区间的最大高度,在每次添加方块的时候,首先查询该区间的最大高度,然后更新该区间的最大高度,并将总的最大高度添加到结果中。该题目之所以为Hard,是因为题意中的数据量较大,且
left
的范围较广,如果使用传统的基于数组的实现方法会导致内存溢出,此时可考虑使用哈希表来存储线段树的节点,以减少内存的使用(因为暂时用不到的节点可以先不存储在哈希表中,节点按需创建)。此外,还需要用到延迟标记来实现区间更新的懒惰更新,避免区间范围过大时导致递归很深且内存占用过大。class SegTree: def __init__(self, n: int): self._n = n self._tree = {} # 存储每个节点对应区间的最大值 self._delay = {} # 延迟更新标记 def _delay_update(self, cur_idx: int): """ 处理cur_idx对应区间的节点的延迟更新标记 更新两个子节点的数据,并将更新标记下发到两个子节点中(当遍历到子节点的时候再继续往下下发) """ lc_idx = cur_idx * 2 + 1 rc_idx = cur_idx * 2 + 2 # 当前节点有延迟更新标记(意味着子节点数据还没有更新) # 更新两个子节点的数据 # 并将延迟标记下发到孩子节点中 self._tree[lc_idx] = self._delay[lc_idx] = self._delay[cur_idx] self._tree[rc_idx] = self._delay[rc_idx] = self._delay[cur_idx] # 下发完毕, 删除本节点的延迟更新标记 del self._delay[cur_idx] def _search_helper(self, l: int, r: int, cur_idx: int, s: int, t: int) -> int: """ 递归搜索以获取区间[l, r]的最大值 :param l 待搜索区间的左侧 :param r 待搜索区间的右侧 :param s 当前区间左侧 :param t 当前区间右侧 :param cur_idx 当前区间的节点序号 """ if l <= s and t <= r: # 当前区间[s, t]包含在待搜索区间的[l, r]中,直接返回当前节点所保存的区间[s, t]最大值 return self._tree.get(cur_idx, 0) if t < l or s > r: # [s, t]和[l, r]没有交集, 提前return return 0 if self._delay.get(cur_idx, 0): # 有延迟更新标记, 处理下 self._delay_update(cur_idx) lc_idx = cur_idx * 2 + 1 rc_idx = cur_idx * 2 + 2 mid = (s + t) // 2 # 递归下去搜索 return max( self._search_helper(l, r, lc_idx, s, mid), self._search_helper(l, r, rc_idx, mid + 1, t) ) def search(self, l: int, r: int) -> int: return self._search_helper(l, r, 0, 0, self._n - 1) def _update_helper(self, l: int, r: int, val: int, cur_idx: int, s: int, t: int): """ 递归更新区间[l, r]的最大值为val :param l 待更新区间的左侧 :param r 待更新区间的右侧 :param val 更新后的值 :param s 当前区间左侧 :param t 当前区间右侧 :param cur_idx 当前区间的节点序号 """ if t < l or s > r: # [s, t]和[l, r]没有交集, 提前return return if l <= s and t <= r: # [s, t]包含在[l, r]中,则先**仅对该节点**进行更新操作,并设置延迟更新标记 self._tree[cur_idx] = self._delay[cur_idx] = val return if self._delay.get(cur_idx, 0): # 更新该节点前,如果该节点有延迟更新标记(上轮更新导致的), 先处理上一轮的延迟更新后,再进行本轮的更新 self._delay_update(cur_idx) lc_idx = cur_idx * 2 + 1 rc_idx = cur_idx * 2 + 2 mid = (s + t) // 2 self._update_helper(l, r, val, lc_idx, s, mid) self._update_helper(l, r, val, rc_idx, mid + 1, t) # 更新操作完毕后,也需要更新本节点的值(因为区间有交集),做法为取两个子节点中的最大值即可 self._tree[cur_idx] = max(self._tree.get(lc_idx, 0), self._tree.get(rc_idx, 0)) def update(self, l: int, r: int, val: int): self._update_helper(l, r, val, 0, 0, self._n - 1) def root_val(self): """ 返回整个数组的最大值 """ return self._tree[0] class Solution: def fallingSquares(self, positions: List[List[int]]) -> List[int]: r_bound = max(left + side for left, side in positions) seg_tree = SegTree(r_bound + 1) ans = [] for left, side in positions: cur_max_height = seg_tree.search(left, left + side - 1) seg_tree.update(left, left + side - 1, cur_max_height + side) ans.append(seg_tree.root_val()) return ans
MLE的做法,因为没有做延迟标记,导致递归深度过大,内存占用过大。
class SegTree: def __init__(self, n: int): self._n = n self._tree = {} def _search_helper(self, l: int, r: int, cur_idx: int, s: int, t: int) -> int: if l <= s and t <= r: return self._tree.get(cur_idx, 0) if t < l or s > r: return 0 mid = (s + t) // 2 return max( self._search_helper(l, r, 2 * cur_idx + 1, s, mid), self._search_helper(l, r, 2 * cur_idx + 2, mid + 1, t) ) def search(self, l: int, r: int) -> int: return self._search_helper(l, r, 0, 0, self._n - 1) def _update_helper(self, l: int, r: int, val: int, cur_idx: int, s: int, t: int): if t < l or s > r: return if s == t: # !!! 注意这里直至叶子节点才更新 self._tree[cur_idx] = val return mid = (s + t) // 2 self._update_helper(l, r, val, 2 * cur_idx + 1, s, mid) self._update_helper(l, r, val, 2 * cur_idx + 2, mid + 1, t) self._tree[cur_idx] = max(self._tree.get(cur_idx * 2 + 1, 0), self._tree.get(cur_idx * 2 + 2, 0)) def update(self, l: int, r: int, val: int): self._update_helper(l, r, val, 0, 0, self._n - 1) def root_val(self): return self._tree[0] class Solution: def fallingSquares(self, positions: List[List[int]]) -> List[int]: r_bound = max(left + side for left, side in positions) seg_tree = SegTree(r_bound + 1) ans = [] for left, side in positions: cur_max_height = seg_tree.search(left, left + side - 1) seg_tree.update(left, left + side - 1, cur_max_height + side) ans.append(seg_tree.root_val()) return ans
- LC 题目链接
-