题目链接
题目描述
给定一棵由 个节点组成的树,每个节点被染成红色('R')或黑色('B')。我们需要统计树中“颜色交错”的简单路径的总数。一条路径如果任意相邻的两个节点颜色都不同,则被称为颜色交错路径。单个节点自身也算作一条长度为1的路径。
解题思路
本题要求计算树中所有颜色交错的路径数量。一个朴素的想法是枚举所有节点对 ,然后检查它们之间的唯一路径是否满足颜色交错的条件。但这种方法的时间复杂度至少是
,对于
较大的情况会超时,因此我们需要一种更高效的算法,例如树形动态规划。
核心思想是将所有路径进行划分,确保每条路径只被计算一次。一个常见的划分方法是,将树进行定根处理(例如,任选节点1为根),然后根据每条路径中“深度最浅的节点”(即最靠近根的节点)来进行归类。这样,每条路径都会被唯一地划分到其最高节点上进行统计。
对于树中的任意一个节点 ,以它为最高节点的颜色交错路径可以分为两类:
- 向下路径:路径的一个端点是
,另一个端点在
的子树中。
- 跨子树路径:路径的两个端点分别位于
的两个不同子节点的子树中,路径经过
。
为了实现这个统计,我们设计一个深度优先搜索(DFS)函数,该函数在进行后序遍历的同时完成计算。我们定义一个 DP 状态:
:表示以节点
为一个端点,且完全在
的子树中的颜色交错路径的数量。
DP 状态转移:
在 dfs(u, parent)
函数中,我们计算 。
- 首先,路径只包含
节点本身,所以
初始化为 1。
- 然后,遍历
的所有子节点
。如果
和
的颜色不同,那么从
出发的所有向下交错路径,都可以和
连接起来,形成新的从
出发的向下交错路径。因此,我们将
累加到
上。
路径计数:
在计算完 的所有子节点的
值之后,我们就可以在节点
这里统计所有以它为最高节点的路径了。
- 向下路径:根据我们的定义,从
出发的向下交错路径总数就是
。我们将它计入总答案。
- 跨子树路径:这类路径连接了
的两个不同子树。假设
和
是
的两个不同子节点,且
与它们的颜色都不同。那么,从
子树出发到
的路径有
条,从
子树出发到
的路径有
条。这两组路径可以在
点拼接,形成
条新的跨子树路径。 因此,我们需要计算所有满足条件的子节点对
的
值乘积之和:
。 这个求和可以通过一个数学技巧简化:
。
我们将一个全局变量 total_paths
用于累加所有节点的贡献。在 dfs(u, parent)
的最后,我们将上述两类路径的数量加入 total_paths
,并返回 值给父节点使用。
代码
#include <iostream>
#include <vector>
#include <string>
using namespace std;
vector<int> adj[200005];
string colors;
long long total_paths = 0;
// DFS函数返回以u为端点,在其子树中的交错路径数
long long dfs_count(int u, int p) {
// dp[u]: 以u为端点的向下交错路径数
long long dp_u = 1;
vector<long long> valid_child_dp_values;
for (int v : adj[u]) {
if (v == p) continue;
long long dp_v = dfs_count(v, u);
if (colors[u - 1] != colors[v - 1]) {
dp_u += dp_v;
valid_child_dp_values.push_back(dp_v);
}
}
// 1. 计入以u为最高点的向下路径
total_paths += dp_u;
// 2. 计入以u为最高点的跨子树路径
long long sum_of_dp = 0;
long long sum_of_dp_squares = 0;
for (long long val : valid_child_dp_values) {
sum_of_dp += val;
}
for (long long val : valid_child_dp_values) {
sum_of_dp_squares += val * val;
}
long long cross_paths = (sum_of_dp * sum_of_dp - sum_of_dp_squares) / 2;
total_paths += cross_paths;
return dp_u;
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n;
cin >> n;
for (int i = 0; i < n - 1; ++i) {
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
cin >> colors;
dfs_count(1, 0);
cout << total_paths << endl;
return 0;
}
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
public class Main {
static List<Integer>[] adj;
static String colors;
static long totalPaths = 0;
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
adj = new ArrayList[n + 1];
for (int i = 1; i <= n; i++) {
adj[i] = new ArrayList<>();
}
for (int i = 0; i < n - 1; i++) {
int u = sc.nextInt();
int v = sc.nextInt();
adj[u].add(v);
adj[v].add(u);
}
colors = sc.next();
dfsCount(1, 0);
System.out.println(totalPaths);
}
private static long dfsCount(int u, int p) {
// dp_u: 以u为端点的向下交错路径数
long dpU = 1;
List<Long> validChildDpValues = new ArrayList<>();
for (int v : adj[u]) {
if (v == p) continue;
long dpV = dfsCount(v, u);
if (colors.charAt(u - 1) != colors.charAt(v - 1)) {
dpU += dpV;
validChildDpValues.add(dpV);
}
}
// 1. 计入以u为最高点的向下路径
totalPaths += dpU;
// 2. 计入以u为最高点的跨子树路径
long sumOfDp = 0;
for (long val : validChildDpValues) {
sumOfDp += val;
}
long sumOfDpSquares = 0;
for (long val : validChildDpValues) {
sumOfDpSquares += val * val;
}
long crossPaths = (sumOfDp * sumOfDp - sumOfDpSquares) / 2;
totalPaths += crossPaths;
return dpU;
}
}
import sys
# 增加递归深度限制
sys.setrecursionlimit(200005)
import sys
# 增加递归深度限制
sys.setrecursionlimit(200005)
n=int(input())
adj = [[] for _ in range(n + 1)]
for _ in range(n - 1):
u, v = map(int, input().split())
adj[u].append(v)
adj[v].append(u)
colors = " " + input().strip() # 1-indexed
# 使用列表或字典作为全局变量的替代
result_container = {'total_paths': 0}
def dfs_count(u, p):
# dp_u: 以u为端点的向下交错路径数
dp_u = 1
valid_child_dp_values = []
for v in adj[u]:
if v == p:
continue
dp_v = dfs_count(v, u)
if colors[u] != colors[v]:
dp_u += dp_v
valid_child_dp_values.append(dp_v)
# 1. 计入以u为最高点的向下路径
result_container['total_paths'] += dp_u
# 2. 计入以u为最高点的跨子树路径
sum_of_dp = sum(valid_child_dp_values)
sum_of_dp_squares = sum(val * val for val in valid_child_dp_values)
cross_paths = (sum_of_dp * sum_of_dp - sum_of_dp_squares) // 2
result_container['total_paths'] += cross_paths
return dp_u
dfs_count(1, 0)
print(result_container['total_paths'])
算法及复杂度
- 算法:树形动态规划(Tree DP)、深度优先搜索(DFS)
- 时间复杂度:
,其中
是节点的数量。我们需要遍历每个节点和每条边一次。
- 空间复杂度:
,主要用于存储树的邻接表和DFS的递归栈深度。