Python实战:用树形DP与贪心优化解决PTA完美树问题
第一次看到PTA L3-035这道题时,我被"完美树"的概念吸引了——如何用最小代价调整节点颜色,使得每棵子树的黑白节点数差不超过1?这不仅是算法竞赛中的经典问题,更是学习树形动态规划(DP)与贪心算法结合的绝佳案例。本文将带你从零开始,用Python实现这个问题的完整解决方案,特别关注如何将C++竞赛代码转化为清晰、模块化的Python工程实现。
1. 问题理解与建模
完美树问题的核心在于理解题目要求与约束条件。给定一棵有N个节点的树,每个节点初始为黑色或白色,我们可以花费一定代价改变节点颜色。目标是让整棵树满足:对于任意节点u,以u为根的子树中黑白节点数量差不超过1。
关键观察点:
- 子树性质具有递归特性:父节点的状态依赖于子节点的状态
- 每个子树有三种可能状态:
- 黑比白多1个(状态0)
- 白比黑多1个(状态1)
- 黑白数量相等(状态2)
- 操作代价需要最小化
输入输出示例分析:
# 输入样例解析 """ 10 1 100 3 2 3 4 # 节点1: 黑(1), 代价100, 3个子节点2,3,4 0 20 1 7 # 节点2: 白(0), 代价20, 1个子节点7 ... """ # 输出应为15 (改变节点6和9的颜色)2. 树形DP框架搭建
树形DP是解决树结构问题的利器,我们需要自底向上计算每个节点的最优状态。
2.1 数据结构设计
首先设计合适的数据结构存储树和DP状态:
from typing import List, Tuple class TreeNode: def __init__(self, idx: int): self.idx = idx self.color: int = 0 # 0: white, 1: black self.cost: int = 0 # 改变颜色的代价 self.children: List[TreeNode] = [] self.size: int = 1 # 子树大小 # DP状态: [黑多1, 白多1, 数量相等] self.dp = [float('inf')] * 32.2 递归DP实现
核心递归函数处理每个节点的状态转移:
def dfs(node: TreeNode): total = 0 heap = [] # 用于贪心选择的最小堆 for child in node.children: dfs(child) node.size += child.size if child.size % 2 == 1: # 奇数大小的子树 heapq.heappush(heap, child.dp[1] - child.dp[0]) total += child.dp[0] else: # 偶数大小的子树 total += child.dp[2] # 考虑当前节点是否需要变色 if node.color == 0: # 白色 heapq.heappush(heap, node.cost) else: # 黑色 total += node.cost heapq.heappush(heap, -node.cost) # 贪心选择最小的k个差值 k = len(heap) for _ in range(k // 2): total += heapq.heappop(heap) # 根据子树大小奇偶性设置状态 if node.size % 2 == 1: # 奇数 node.dp[0] = total if heap: # 还能再取一个 node.dp[1] = total + heap[0] else: # 偶数 node.dp[2] = total3. 贪心优化与优先队列实现
原题解使用了优先队列(最小堆)来优化选择过程,这在Python中可以用heapq模块实现。
关键优化点:
- 只处理奇数大小的子树(它们会影响父节点的状态)
- 通过维护差值的最小堆,确保每次选择都是局部最优
import heapq def build_tree(input_data: List[str]) -> TreeNode: """根据输入数据构建树结构""" n = int(input_data[0]) nodes = [TreeNode(i) for i in range(n+1)] # 1-based索引 for i in range(1, n+1): parts = list(map(int, input_data[i].split())) nodes[i].color = parts[0] nodes[i].cost = parts[1] child_count = parts[2] for j in range(child_count): child_idx = parts[3+j] nodes[i].children.append(nodes[child_idx]) return nodes[1] # 返回根节点4. 完整解决方案与测试
将各部分组合起来,并添加输入输出处理:
def solve_perfect_tree(input_data: List[str]) -> int: root = build_tree(input_data) dfs(root) return min(root.dp) # 测试用例 test_input = [ "10", "1 100 3 2 3 4", "0 20 1 7", "0 5 2 5 6", "0 8 1 10", "0 7 0", "0 1 1 2", "0 15 0", "0 13 0", "1 8 0", "0 2 0" ] print(solve_perfect_tree(test_input)) # 应输出155. 工程化改进与性能优化
将代码工程化,使其更易于维护和扩展:
5.1 模块化设计
class PerfectTreeSolver: def __init__(self): self.root = None def load_input(self, input_data: List[str]): self.root = build_tree(input_data) def solve(self) -> int: if not self.root: raise ValueError("Input not loaded") dfs(self.root) return min(self.root.dp) @staticmethod def from_file(file_path: str) -> 'PerfectTreeSolver': with open(file_path) as f: input_data = [line.strip() for line in f if line.strip()] solver = PerfectTreeSolver() solver.load_input(input_data) return solver5.2 性能考量
对于N=1e5的大规模数据:
- Python的递归深度限制可能成为问题,可以改用迭代DFS
- 使用更高效的数据结构减少常数时间
# 迭代版DFS实现 def dfs_iterative(root: TreeNode): stack = [(root, False)] post_order = [] while stack: node, visited = stack.pop() if visited: # 后序处理 process_node(node) post_order.append(node) else: stack.append((node, True)) # 逆序压栈保证处理顺序 for child in reversed(node.children): stack.append((child, False)) return post_order6. 测试用例设计与验证
完善的测试是工程化的重要部分:
import unittest class TestPerfectTree(unittest.TestCase): def test_small_case(self): input_data = [ "3", "1 10 2 2 3", # 黑(10) -> 两个孩子 "0 5 0", # 白(5) "0 5 0" # 白(5) ] self.assertEqual(solve_perfect_tree(input_data), 0) # 已经完美 def test_sample_case(self): solver = PerfectTreeSolver() solver.load_input(test_input) self.assertEqual(solver.solve(), 15) if __name__ == "__main__": unittest.main()7. 可视化与调试技巧
添加可视化功能帮助理解中间结果:
def print_tree(node: TreeNode, indent=0): color_map = {0: "W", 1: "B"} print(" " * indent + f"Node {node.idx}({color_map[node.color]}, {node.cost}): " f"dp={node.dp}, size={node.size}") for child in node.children: print_tree(child, indent + 1) # 在solve_perfect_tree中添加: # print_tree(root) # 调试时查看树结构通过这个项目,我们不仅解决了算法问题,还实践了如何将竞赛代码转化为可维护的Python工程。关键在于理解算法本质后,用适合Python的方式重新实现,同时保持代码的清晰性和模块化。