题目链接

世界树上找米库

题目描述

在一个由 个地点和 条道路构成的树形结构中,我们需要找到所有“Miku”点。

  • Sekai 点:只连接一条道路的地点,即树的叶子节点(度为 1)。
  • Miku 点:必须满足两个条件:
    1. 它不能是 Sekai 点(度大于 1)。
    2. 在所有非 Sekai 点中,它到最近的 Sekai 点的距离是最大的。

任务是找出所有 Miku 点,并按编号从小到大输出。

解题思路

这个问题的核心是计算出树上每一个节点到离它最近的叶子节点的距离。一旦这个距离计算出来,我们就可以根据 Miku 点的定义进行筛选。

这是一个典型的多源广度优先搜索(Multi-Source BFS) 问题。我们可以把所有的叶子节点(Sekai 点)作为 BFS 的初始源点,同时开始进行搜索。通过这种方式,当我们第一次访问到一个节点时,所经过的路径长度就是该节点到最近叶子节点的距离。

算法可以分为以下几个步骤:

  1. 预处理和初始化

    • 构建图的邻接表,并计算每个节点的度数
    • 初始化一个距离数组 ,所有值设为 -1 表示未访问。
    • 创建一个队列,并将所有度为 1 的节点(Sekai 点)加入队列,同时将它们的距离 设为 0。
  2. 多源 BFS 计算距离

    • 执行标准的 BFS 过程。当队列不为空时,出队一个节点
    • 遍历 的所有邻居 。如果 尚未被访问过(即 ),说明我们找到了从某个叶子节点到 的最短路径。
    • 更新 的距离 ,并将 入队。
  3. 筛选 Miku 点

    • BFS 结束后, 数组就存储了每个节点到最近 Sekai 点的距离。
    • 找出所有非 Sekai 点()中,最大的距离值
    • 创建一个列表,用于存放 Miku 点的编号。
    • 再次遍历所有节点 ,如果一个节点满足 并且 ,那么它就是一个 Miku 点,将其加入列表。
    • 由于我们是按节点编号从小到大遍历并添加的,所以最终得到的列表自然是有序的。
  4. 处理特殊情况

    • 当节点数 时,所有节点的度数都为 1,不存在度大于 1 的节点,因此没有 Miku 点。可以直接输出 0。

整个算法对于每个测试用例,只需要一次 BFS 和几次线性扫描,效率很高。

代码

#include <iostream>
#include <vector>
#include <queue>
#include <algorithm>

using namespace std;

void solve() {
    int n;
    cin >> n;
    vector<vector<int>> adj(n + 1);
    vector<int> deg(n + 1, 0);
    for (int i = 0; i < n - 1; ++i) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
        deg[u]++;
        deg[v]++;
    }

    if (n <= 2) {
        cout << 0 << endl << endl;
        return;
    }

    queue<int> q;
    vector<int> dist(n + 1, -1);

    for (int i = 1; i <= n; ++i) {
        if (deg[i] == 1) {
            q.push(i);
            dist[i] = 0;
        }
    }

    while(!q.empty()){
        int u = q.front();
        q.pop();

        for(int v : adj[u]){
            if(dist[v] == -1){
                dist[v] = dist[u] + 1;
                q.push(v);
            }
        }
    }

    int max_dist = -1;
    for (int i = 1; i <= n; ++i) {
        if (deg[i] > 1) {
            if (dist[i] > max_dist) {
                max_dist = dist[i];
            }
        }
    }

    vector<int> miku_points;
    for (int i = 1; i <= n; ++i) {
        if (deg[i] > 1 && dist[i] == max_dist) {
            miku_points.push_back(i);
        }
    }

    cout << miku_points.size() << endl;
    for (size_t i = 0; i < miku_points.size(); ++i) {
        cout << miku_points[i] << (i == miku_points.size() - 1 ? "" : " ");
    }
    cout << endl;
}

int main() {
    int t;
    cin >> t;
    while (t--) {
        solve();
    }
    return 0;
}
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int t = sc.nextInt();
        while (t-- > 0) {
            solve(sc);
        }
    }

    private static void solve(Scanner sc) {
        int n = sc.nextInt();
        ArrayList<Integer>[] adj = new ArrayList[n + 1];
        for (int i = 0; i <= n; i++) {
            adj[i] = new ArrayList<>();
        }
        int[] deg = new int[n + 1];

        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt();
            int v = sc.nextInt();
            adj[u].add(v);
            adj[v].add(u);
            deg[u]++;
            deg[v]++;
        }

        if (n <= 2) {
            System.out.println(0);
            System.out.println();
            return;
        }

        Queue<Integer> q = new LinkedList<>();
        int[] dist = new int[n + 1];
        for (int i = 0; i <= n; i++) {
            dist[i] = -1;
        }

        for (int i = 1; i <= n; i++) {
            if (deg[i] == 1) {
                q.add(i);
                dist[i] = 0;
            }
        }

        while (!q.isEmpty()) {
            int u = q.poll();
            for (int v : adj[u]) {
                if (dist[v] == -1) {
                    dist[v] = dist[u] + 1;
                    q.add(v);
                }
            }
        }

        int maxDist = -1;
        for (int i = 1; i <= n; i++) {
            if (deg[i] > 1) {
                if (dist[i] > maxDist) {
                    maxDist = dist[i];
                }
            }
        }

        List<Integer> mikuPoints = new ArrayList<>();
        for (int i = 1; i <= n; i++) {
            if (deg[i] > 1 && dist[i] == maxDist) {
                mikuPoints.add(i);
            }
        }

        System.out.println(mikuPoints.size());
        for (int i = 0; i < mikuPoints.size(); i++) {
            System.out.print(mikuPoints.get(i) + (i == mikuPoints.size() - 1 ? "" : " "));
        }
        System.out.println();
    }
}
from collections import deque
import sys

def solve():
    line = sys.stdin.readline()
    if not line: return
    n = int(line)
    
    adj = [[] for _ in range(n + 1)]
    deg = [0] * (n + 1)
    for _ in range(n - 1):
        u, v = map(int, sys.stdin.readline().split())
        adj[u].append(v)
        adj[v].append(u)
        deg[u] += 1
        deg[v] += 1

    if n <= 2:
        print(0)
        print()
        return

    q = deque()
    dist = [-1] * (n + 1)

    for i in range(1, n + 1):
        if deg[i] == 1:
            q.append(i)
            dist[i] = 0
    
    while q:
        u = q.popleft()
        for v in adj[u]:
            if dist[v] == -1:
                dist[v] = dist[u] + 1
                q.append(v)

    max_dist = -1
    for i in range(1, n + 1):
        if deg[i] > 1 and dist[i] > max_dist:
            max_dist = dist[i]

    miku_points = []
    for i in range(1, n + 1):
        if deg[i] > 1 and dist[i] == max_dist:
            miku_points.append(i)

    print(len(miku_points))
    print(*miku_points)

def main():
    line = sys.stdin.readline()
    if not line: return
    t = int(line)
    for _ in range(t):
        solve()

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:多源广度优先搜索(Multi-Source BFS)。
  • 时间复杂度,对于每个测试用例,时间复杂度为 ,因为建图、BFS 和后续的扫描都是线性的。 表示所有测试用例的节点数之和。
  • 空间复杂度,主要用于存储邻接表、度数数组、距离数组和队列。