-
发表于 2024.05.09
-
三色标记法+DFS,使用一个数组维护每个节点可能的三种状态:未遍历,正在遍历,遍历完毕。然后使用DFS计算每个节点是否为安全节点:如果其为终端节点,即也是安全节点;否则,遍历所有的路径,如果所有的路径经过的节点皆为安全节点,则该节点也为安全节点;如果碰上一个正在遍历的节点(在递归栈中),则存在环,说明该路径是走不到终端节点的。
这道题其实也是一个比较典型的拓扑排序应用题(毕竟终端节点定义为出度为0,很符合拓扑排序的特征),可以将这个图“翻转”起来,即调转所有的边的方向,然后应用拓扑排序将入度为0的节点添加到结果中并剥离,直至图中只剩下环。
class Solution: def eventualSafeNodes(self, graph: List[List[int]]) -> List[int]: n = len(graph) # 寻找终端节点 terminal_nodes = set() for u, adj_list in enumerate(graph): if not adj_list: terminal_nodes.add(u) safe_nodes = set(terminal_nodes) # === 原来这就是三色标记法 === # 其实就是用来维护三种可能的状态嘛 visit_status = [0] * n # 0: 从来未访问 1: 正在访问中 2: 访问完毕 def dfs(u: int) -> bool: """ 遍历节点u的所有路径, 最终返回u是不是安全节点 """ if u in terminal_nodes: # 已经到达终端节点, 返回True return True if visit_status[u] == 1: # 发现环, 说明这条路径是走不通的 return False if visit_status[u] == 2: # 如果遍历的路径经过一个已经遍历的节点 return u in safe_nodes # 直接检查其是否安全节点即可 visit_status[u] = 1 # 标记正在访问 is_safe = True for v in graph[u]: if not dfs(v): is_safe = False break visit_status[u] = 2 # 标记访问完毕 if is_safe: safe_nodes.add(u) return is_safe for u in range(n): if u not in terminal_nodes and visit_status[u] == 0: dfs(u) return sorted(safe_nodes)
一个更简洁一点的方法,去掉
terminal_nodes
(反正遍历的时候,终端节点本就没有出路,直接走dfs逻辑就好了)和safe_nodes
,直接使用状态1
共同表示正在访问中/非安全节点,状态2
表示已经访问完毕且是安全节点。class Solution: def eventualSafeNodes(self, graph: List[List[int]]) -> List[int]: n = len(graph) # === 原来这就是三色标记法 === # 其实就是用来维护三种可能的状态嘛 visit_status = [0] * n # 0: 从来未访问 1: 正在访问中 or 非安全节点 2: 安全节点 def dfs(u: int) -> bool: """ 遍历节点u的所有路径, 最终返回u是不是安全节点 """ if visit_status[u] == 1: return False if visit_status[u] == 2: # 如果遍历的路径经过一个已经遍历的节点 return True visit_status[u] = 1 # 标记正在访问 for v in graph[u]: if not dfs(v): # 直接返回就行了 # 此时的状态1表示为非安全节点了! return False visit_status[u] = 2 # 标记访问完毕 return True ans = [] for u in range(n): if visit_status[u] == 0: dfs(u) if visit_status[u] == 2: ans.append(u) return ans
应用拓扑排序的版本,注意寻找入度为0的节点不要循环,会超时。维护一个队列来保存入度为0的节点。
from collections import deque class Solution: def eventualSafeNodes(self, graph: List[List[int]]) -> List[int]: n = len(graph) rev_graph = [[] for _ in range(n)] for u, adj_list in enumerate(graph): for v in adj_list: rev_graph[v].append(u) q = deque() in_deg = [0] * n for u in range(n): in_deg[u] = len(graph[u]) if in_deg[u] == 0: q.append(u) ans = set() while q: zero_deg_node = q.popleft() ans.add(zero_deg_node) for v in rev_graph[zero_deg_node]: in_deg[v] -= 1 if not in_deg[v]: q.append(v) return sorted(ans)
- LC 题目链接
-