题目链接

PEEK59 机房传信

题目描述

给定一个由 个节点构成的树形网络。每个节点 转发信息会产生一个延迟,延迟时间等于该节点的度数(即与之直接相连的边的数量)。信息在两点 之间的传递时间,定义为 之间唯一路径上所有节点(包括 )的延迟时间之和。

你需要处理 次独立的查询,每次给定两个节点 ,求出它们之间的信息传递最短时间。

解题思路

由于网络是树形结构,任意两点 之间的路径是唯一的,因此“最短时间”就是这条唯一路径上的权值(延迟)之和。节点 的权值等于其度数 deg[i]

对每次查询都进行一次 DFS/BFS 来寻找路径并求和,时间复杂度为 ,会超时。这是一个典型的树上路径问题,可以使用最近公共祖先(LCA) 结合倍增(Binary Lifting) 来高效解决。

算法的核心思想是:任意节点 之间的路径都可以被它们的最近公共祖先 分为两段:。因此,路径 的总权值和为: path_sum(u, v) = path_sum(u, l) + path_sum(v, l) - deg[l] (因为 在两条路径中都被计算了一次,所以要减去一次它的权值)。

为了快速计算任意节点到其某个祖先的路径和,我们可以进行预处理:

  1. 预处理

    • 首先,根据输入的边构建邻接表,并计算所有节点的度数 deg[i]
    • 从任意节点(如节点1)作为根,进行一次深度优先搜索(DFS)。在DFS中,为每个节点 计算:
      • 深度 depth[u]
      • 父节点 up[0][u]
      • 从根到节点 的路径权值和 dist[u],其递推式为 dist[u] = dist[parent] + deg[u]
    • 利用 up[0] 数组和递推关系 up[p][u] = up[p-1][up[p-1][u]] 构建完整的倍增表,用于快速查询LCA。此过程复杂度为
  2. 查询: 对于每次查询

    • 使用倍增表在 时间内找到
    • 利用预计算的 dist 数组,在 时间内计算路径总和。路径 的权值和为 dist[u] - dist[l] + deg[l]。因此,总路径和的计算公式为: total_sum = (dist[u] - dist[l] + deg[l]) + (dist[v] - dist[l] + deg[l]) - deg[l] 化简后得到: total_sum = dist[u] + dist[v] - 2 * dist[l] + deg[l]
    • 如果查询的 是同一个节点,该公式也成立,结果为 deg[u]

代码

#include <iostream>
#include <vector>
#include <numeric>

using namespace std;

const int MAXN = 100005;
const int LOGN = 18;

vector<int> adj[MAXN];
int deg[MAXN];
long long dist[MAXN];
int depth[MAXN];
int up[LOGN][MAXN];
int n, m;

void dfs(int u, int p, int d, long long current_dist) {
    depth[u] = d;
    up[0][u] = p;
    dist[u] = current_dist + deg[u];
    for (int v : adj[u]) {
        if (v != p) {
            dfs(v, u, d + 1, dist[u]);
        }
    }
}

int lca(int u, int v) {
    if (depth[u] < depth[v]) swap(u, v);
    
    for (int i = LOGN - 1; i >= 0; --i) {
        if (depth[u] - (1 << i) >= depth[v]) {
            u = up[i][u];
        }
    }

    if (u == v) return u;

    for (int i = LOGN - 1; i >= 0; --i) {
        if (up[i][u] != up[i][v]) {
            u = up[i][u];
            v = up[i][v];
        }
    }
    return up[0][u];
}

int main() {
    cin >> n >> m;
    for (int i = 0; i < n - 1; ++i) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
        deg[u]++;
        deg[v]++;
    }

    dfs(1, 0, 1, 0);

    for (int i = 1; i < LOGN; ++i) {
        for (int j = 1; j <= n; ++j) {
            up[i][j] = up[i - 1][up[i - 1][j]];
        }
    }

    for (int i = 0; i < m; ++i) {
        int u, v;
        cin >> u >> v;
        if (u == v) {
            cout << deg[u] << "\n";
            continue;
        }
        int l = lca(u, v);
        long long result = dist[u] + dist[v] - 2 * dist[l] + deg[l];
        cout << result << "\n";
    }

    return 0;
}
import java.util.ArrayList;
import java.util.Scanner;

public class Main {
    static final int MAXN = 100005;
    static final int LOGN = 18;
    static ArrayList<Integer>[] adj = new ArrayList[MAXN];
    static int[] deg = new int[MAXN];
    static long[] dist = new long[MAXN];
    static int[] depth = new int[MAXN];
    static int[][] up = new int[LOGN][MAXN];
    static int n, m;

    static void dfs(int u, int p, int d, long currentDist) {
        depth[u] = d;
        up[0][u] = p;
        dist[u] = currentDist; // dist[parent] is passed as currentDist
        for (int v : adj[u]) {
            if (v != p) {
                dfs(v, u, d + 1, dist[u] + deg[u]);
            }
        }
    }
    
    // A slightly different DFS to calculate dist correctly
    static void dfs_dist(int u, int p) {
        for(int v : adj[u]) {
            if (v != p) {
                dist[v] = dist[u] + deg[v];
                dfs_dist(v, u);
            }
        }
    }


    static int lca(int u, int v) {
        if (depth[u] < depth[v]) {
            int temp = u; u = v; v = temp;
        }

        for (int i = LOGN - 1; i >= 0; i--) {
            if (depth[u] - (1 << i) >= depth[v]) {
                u = up[i][u];
            }
        }

        if (u == v) return u;

        for (int i = LOGN - 1; i >= 0; i--) {
            if (up[i][u] != up[i][v]) {
                u = up[i][u];
                v = up[i][v];
            }
        }
        return up[0][u];
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        n = sc.nextInt();
        m = sc.nextInt();

        for (int i = 1; i <= n; i++) adj[i] = new ArrayList<>();

        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt();
            int v = sc.nextInt();
            adj[u].add(v);
            adj[v].add(u);
            deg[u]++;
            deg[v]++;
        }

        dfs(1, 0, 1, 0); // DFS for depth and parents
        
        dist[1] = deg[1];
        dfs_dist(1, 0); // DFS to calculate path sums from root

        for (int i = 1; i < LOGN; i++) {
            for (int j = 1; j <= n; j++) {
                up[i][j] = up[i - 1][up[i - 1][j]];
            }
        }

        for (int i = 0; i < m; i++) {
            int u = sc.nextInt();
            int v = sc.nextInt();
            if (u == v) {
                System.out.println(deg[u]);
                continue;
            }
            int l = lca(u, v);
            long result = dist[u] + dist[v] - 2 * dist[l] + deg[l];
            System.out.println(result);
        }
    }
}
import sys

sys.setrecursionlimit(100005)

def solve():
    n, m = map(int, input().split())
    
    MAXN = n + 1
    LOGN = (n).bit_length()
    
    adj = [[] for _ in range(MAXN)]
    deg = [0] * MAXN
    for _ in range(n - 1):
        u, v = map(int, input().split())
        adj[u].append(v)
        adj[v].append(u)
        deg[u] += 1
        deg[v] += 1

    depth = [0] * MAXN
    up = [[0] * MAXN for _ in range(LOGN)]
    dist = [0] * MAXN
    
    # DFS to build parent table, depths and path sums from root
    q = [(1, 0, 1, 0)] # u, p, d, current_dist
    visited = {1}
    head = 0
    while head < len(q):
        u, p, d, current_dist = q[head]
        head += 1
        
        depth[u] = d
        up[0][u] = p
        dist[u] = current_dist + deg[u]
        
        for v in adj[u]:
            if v != p:
                visited.add(v)
                q.append((v, u, d + 1, dist[u]))
    
    for i in range(1, LOGN):
        for j in range(1, n + 1):
            up[i][j] = up[i - 1][up[i - 1][j]]
            
    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        
        for i in range(LOGN - 1, -1, -1):
            if depth[u] - (1 << i) >= depth[v]:
                u = up[i][u]
        
        if u == v:
            return u
            
        for i in range(LOGN - 1, -1, -1):
            if up[i][u] != up[i][v]:
                u = up[i][u]
                v = up[i][v]
                
        return up[0][u]

    results = []
    for _ in range(m):
        u, v = map(int, input().split())
        if u == v:
            results.append(str(deg[u]))
            continue
        l = lca(u, v)
        result = dist[u] + dist[v] - 2 * dist[l] + deg[l]
        results.append(str(result))
    
    print("\n".join(results))

solve()

算法及复杂度

  • 算法:最近公共祖先(LCA) + 倍增法
  • 时间复杂度。预处理阶段需要 ,后续的 次查询每次需要
  • 空间复杂度,主要用于存储倍增表和其它辅助数组。