题目链接

颜色交错路径计数

题目描述

给定一棵由 个节点组成的树,每个节点被染成红色('R')或黑色('B')。我们需要统计树中“颜色交错”的简单路径的总数。一条路径如果任意相邻的两个节点颜色都不同,则被称为颜色交错路径。单个节点自身也算作一条长度为1的路径。

解题思路

本题要求计算树中所有颜色交错的路径数量。一个朴素的想法是枚举所有节点对 ,然后检查它们之间的唯一路径是否满足颜色交错的条件。但这种方法的时间复杂度至少是 ,对于 较大的情况会超时,因此我们需要一种更高效的算法,例如树形动态规划。

核心思想是将所有路径进行划分,确保每条路径只被计算一次。一个常见的划分方法是,将树进行定根处理(例如,任选节点1为根),然后根据每条路径中“深度最浅的节点”(即最靠近根的节点)来进行归类。这样,每条路径都会被唯一地划分到其最高节点上进行统计。

对于树中的任意一个节点 ,以它为最高节点的颜色交错路径可以分为两类:

  1. 向下路径:路径的一个端点是 ,另一个端点在 的子树中。
  2. 跨子树路径:路径的两个端点分别位于 的两个不同子节点的子树中,路径经过

为了实现这个统计,我们设计一个深度优先搜索(DFS)函数,该函数在进行后序遍历的同时完成计算。我们定义一个 DP 状态: :表示以节点 为一个端点,且完全在 的子树中的颜色交错路径的数量。

DP 状态转移:dfs(u, parent) 函数中,我们计算

  • 首先,路径只包含 节点本身,所以 初始化为 1。
  • 然后,遍历 的所有子节点 。如果 的颜色不同,那么从 出发的所有向下交错路径,都可以和 连接起来,形成新的从 出发的向下交错路径。因此,我们将 累加到 上。

路径计数: 在计算完 的所有子节点的 值之后,我们就可以在节点 这里统计所有以它为最高节点的路径了。

  1. 向下路径:根据我们的定义,从 出发的向下交错路径总数就是 。我们将它计入总答案。
  2. 跨子树路径:这类路径连接了 的两个不同子树。假设 的两个不同子节点,且 与它们的颜色都不同。那么,从 子树出发到 的路径有 条,从 子树出发到 的路径有 条。这两组路径可以在 点拼接,形成 条新的跨子树路径。 因此,我们需要计算所有满足条件的子节点对 值乘积之和:。 这个求和可以通过一个数学技巧简化:

我们将一个全局变量 total_paths 用于累加所有节点的贡献。在 dfs(u, parent) 的最后,我们将上述两类路径的数量加入 total_paths,并返回 值给父节点使用。

代码

#include <iostream>
#include <vector>
#include <string>

using namespace std;

vector<int> adj[200005];
string colors;
long long total_paths = 0;

// DFS函数返回以u为端点,在其子树中的交错路径数
long long dfs_count(int u, int p) {
    // dp[u]: 以u为端点的向下交错路径数
    long long dp_u = 1;
    
    vector<long long> valid_child_dp_values;
    
    for (int v : adj[u]) {
        if (v == p) continue;
        
        long long dp_v = dfs_count(v, u);
        
        if (colors[u - 1] != colors[v - 1]) {
            dp_u += dp_v;
            valid_child_dp_values.push_back(dp_v);
        }
    }
    
    // 1. 计入以u为最高点的向下路径
    total_paths += dp_u;
    
    // 2. 计入以u为最高点的跨子树路径
    long long sum_of_dp = 0;
    long long sum_of_dp_squares = 0;
    for (long long val : valid_child_dp_values) {
        sum_of_dp += val;
    }
    for (long long val : valid_child_dp_values) {
        sum_of_dp_squares += val * val;
    }
    
    long long cross_paths = (sum_of_dp * sum_of_dp - sum_of_dp_squares) / 2;
    total_paths += cross_paths;
    
    return dp_u;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;

    for (int i = 0; i < n - 1; ++i) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    cin >> colors;
    
    dfs_count(1, 0);
    
    cout << total_paths << endl;

    return 0;
}
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;

public class Main {
    static List<Integer>[] adj;
    static String colors;
    static long totalPaths = 0;

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();

        adj = new ArrayList[n + 1];
        for (int i = 1; i <= n; i++) {
            adj[i] = new ArrayList<>();
        }

        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt();
            int v = sc.nextInt();
            adj[u].add(v);
            adj[v].add(u);
        }
        colors = sc.next();
        
        dfsCount(1, 0);
        
        System.out.println(totalPaths);
    }

    private static long dfsCount(int u, int p) {
        // dp_u: 以u为端点的向下交错路径数
        long dpU = 1;
        
        List<Long> validChildDpValues = new ArrayList<>();
        
        for (int v : adj[u]) {
            if (v == p) continue;
            
            long dpV = dfsCount(v, u);
            
            if (colors.charAt(u - 1) != colors.charAt(v - 1)) {
                dpU += dpV;
                validChildDpValues.add(dpV);
            }
        }
        
        // 1. 计入以u为最高点的向下路径
        totalPaths += dpU;
        
        // 2. 计入以u为最高点的跨子树路径
        long sumOfDp = 0;
        for (long val : validChildDpValues) {
            sumOfDp += val;
        }
        
        long sumOfDpSquares = 0;
        for (long val : validChildDpValues) {
            sumOfDpSquares += val * val;
        }
        
        long crossPaths = (sumOfDp * sumOfDp - sumOfDpSquares) / 2;
        totalPaths += crossPaths;
        
        return dpU;
    }
}
import sys

# 增加递归深度限制
sys.setrecursionlimit(200005)

import sys

# 增加递归深度限制
sys.setrecursionlimit(200005)

n=int(input())

adj = [[] for _ in range(n + 1)]
for _ in range(n - 1):
    u, v = map(int, input().split())
    adj[u].append(v)
    adj[v].append(u)

colors = " " + input().strip() # 1-indexed

# 使用列表或字典作为全局变量的替代
result_container = {'total_paths': 0}

def dfs_count(u, p):
    # dp_u: 以u为端点的向下交错路径数
    dp_u = 1
    
    valid_child_dp_values = []
    
    for v in adj[u]:
        if v == p:
            continue
        
        dp_v = dfs_count(v, u)
        
        if colors[u] != colors[v]:
            dp_u += dp_v
            valid_child_dp_values.append(dp_v)
    
    # 1. 计入以u为最高点的向下路径
    result_container['total_paths'] += dp_u
    
    # 2. 计入以u为最高点的跨子树路径
    sum_of_dp = sum(valid_child_dp_values)
    sum_of_dp_squares = sum(val * val for val in valid_child_dp_values)
    
    cross_paths = (sum_of_dp * sum_of_dp - sum_of_dp_squares) // 2
    result_container['total_paths'] += cross_paths
    
    return dp_u

dfs_count(1, 0)
print(result_container['total_paths'])

算法及复杂度

  • 算法:树形动态规划(Tree DP)、深度优先搜索(DFS)
  • 时间复杂度:,其中 是节点的数量。我们需要遍历每个节点和每条边一次。
  • 空间复杂度:,主要用于存储树的邻接表和DFS的递归栈深度。