小红的根节点个数

[题目链接](https://www.nowcoder.com/practice/bcc26fa906c64a3189939669f446dd73)

思路

问题建模

为节点 的权值的因子个数。当以节点 为根时,对于每条边 的父节点),需满足

换根分析

固定以节点 1 为根,对每条边 (其中 的父节点)分析:

  • 当根 不在 中时,方向为 是子节点),约束为
  • 当根 中时,方向翻转为 是子节点),约束为

因此对每条边 有:

  • :根 必须 中(否则约束违反);
  • :根 不能 中;
  • 若两者都违反:不存在合法根,答案为 0(事实上不可能两者同时成立,因为不能既 )。

求合法根集合

必须包含的约束:所有需要"根在 中"的节点 形成一个集合

合法根需同时在所有这些子树的交集中。利用欧拉序(DFS 入/出时间)判断子树包含关系:两个子树要么嵌套(一个包含另一个),要么不相交。

  • 中存在两个不可比较的节点(即两个子树不相交),则答案为 0;
  • 否则, 中深度最大的节点 对应的子树 就是初始合法集合。

必须排除的约束:对所有"根不能在 中"的节点 ,从合法集合中删去 中的节点。

计数

利用欧拉序, 对应区间 (若 为空,则为 )。

对所有需排除的子树,用差分数组在该区间内标记被排除的位置,最后统计未被覆盖的位置数即为答案。

复杂度

  • 预处理因子数:,其中 为最大权值(约 );
  • 每组测试:
  • 总体:

代码

C++

#include <iostream>
#include <vector>
#include <queue>
#include <stack>
#include <functional>
#include <algorithm>
using namespace std;

const int MAXV = 1000001;
int ndiv[MAXV];

void precompute() {
    for (int i = 1; i < MAXV; i++)
        for (int j = i; j < MAXV; j += i)
            ndiv[j]++;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    precompute();

    int T;
    cin >> T;
    while (T--) {
        int n;
        cin >> n;

        vector<int> a(n+1), d(n+1);
        for (int i = 1; i <= n; i++) {
            cin >> a[i];
            d[i] = ndiv[a[i]];
        }

        vector<vector<int>> adj(n+1);
        for (int i = 0; i < n-1; i++) {
            int u, v; cin >> u >> v;
            adj[u].push_back(v);
            adj[v].push_back(u);
        }

        if (n == 1) { cout << 1 << "\n"; continue; }

        vector<int> par(n+1, 0), depth(n+1, 0), in_time(n+1), out_time(n+1);
        int timer = 0;

        // DFS for euler tour
        function<void(int,int)> dfs = [&](int u, int p) {
            in_time[u] = timer++;
            for (int v : adj[u]) {
                if (v != p) {
                    par[v] = u;
                    depth[v] = depth[u] + 1;
                    dfs(v, u);
                }
            }
            out_time[u] = timer - 1;
        };
        dfs(1, 0);

        auto inSubtree = [&](int r, int v) {
            return in_time[v] <= in_time[r] && in_time[r] <= out_time[v];
        };

        int mustIncDeepest = -1;
        bool impossible = false;
        vector<int> mustExc;

        for (int v = 2; v <= n; v++) {
            int u = par[v];
            bool childBad = (d[v] > 2 * d[u]);
            bool parentBad = (d[u] > 2 * d[v]);

            if (childBad) {
                if (mustIncDeepest == -1) {
                    mustIncDeepest = v;
                } else if (inSubtree(v, mustIncDeepest)) {
                    mustIncDeepest = v;
                } else if (!inSubtree(mustIncDeepest, v)) {
                    impossible = true; break;
                }
            }
            if (parentBad) mustExc.push_back(v);
        }

        if (impossible) { cout << 0 << "\n"; continue; }

        int lo = 0, hi = n - 1;
        if (mustIncDeepest != -1) {
            lo = in_time[mustIncDeepest];
            hi = out_time[mustIncDeepest];
        }

        vector<int> diff(n+2, 0);
        for (int v : mustExc) {
            int vlo = max(in_time[v], lo), vhi = min(out_time[v], hi);
            if (vlo > vhi) continue;
            diff[vlo]++; diff[vhi+1]--;
        }

        int count = 0, cur = 0;
        for (int i = lo; i <= hi; i++) {
            cur += diff[i];
            if (cur == 0) count++;
        }
        cout << count << "\n";
    }
    return 0;
}

Java

import java.util.*;
import java.io.*;

public class Main {
    static final int MAXV = 1000001;
    static int[] ndiv = new int[MAXV];

    static void precompute() {
        for (int i = 1; i < MAXV; i++)
            for (int j = i; j < MAXV; j += i)
                ndiv[j]++;
    }

    public static void main(String[] args) throws IOException {
        precompute();
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringBuilder sb = new StringBuilder();
        int T = Integer.parseInt(br.readLine().trim());

        while (T-- > 0) {
            int n = Integer.parseInt(br.readLine().trim());
            int[] d = new int[n + 1];
            StringTokenizer st = new StringTokenizer(br.readLine());
            for (int i = 1; i <= n; i++)
                d[i] = ndiv[Integer.parseInt(st.nextToken())];

            List<List<Integer>> adj = new ArrayList<>();
            for (int i = 0; i <= n; i++) adj.add(new ArrayList<>());
            for (int i = 0; i < n - 1; i++) {
                StringTokenizer st2 = new StringTokenizer(br.readLine());
                int u = Integer.parseInt(st2.nextToken()), v = Integer.parseInt(st2.nextToken());
                adj.get(u).add(v); adj.get(v).add(u);
            }

            if (n == 1) { sb.append(1).append("\n"); continue; }

            int[] depth = new int[n + 1], inTime = new int[n + 1], outTime = new int[n + 1];
            int[] bfsParent = new int[n + 1];
            boolean[] bfsVis = new boolean[n + 1];
            Queue<Integer> bfsQ = new LinkedList<>();
            bfsQ.add(1); bfsVis[1] = true;
            while (!bfsQ.isEmpty()) {
                int u = bfsQ.poll();
                for (int v : adj.get(u)) {
                    if (!bfsVis[v]) {
                        bfsVis[v] = true; bfsParent[v] = u;
                        depth[v] = depth[u] + 1; bfsQ.add(v);
                    }
                }
            }

            int[] timerArr = {0};
            Deque<Integer> dfs = new ArrayDeque<>();
            boolean[] dfsVis = new boolean[n + 1];
            dfsVis[1] = true; dfs.push(~1); dfs.push(1);
            while (!dfs.isEmpty()) {
                int u = dfs.pop();
                if (u < 0) { outTime[~u] = timerArr[0] - 1; }
                else {
                    inTime[u] = timerArr[0]++;
                    dfs.push(~u);
                    for (int v : adj.get(u))
                        if (v != bfsParent[u] && !dfsVis[v]) {
                            dfsVis[v] = true; dfs.push(~v); dfs.push(v);
                        }
                }
            }

            int mustIncDeepest = -1;
            boolean impossible = false;
            List<Integer> mustExc = new ArrayList<>();

            for (int v = 2; v <= n; v++) {
                int u = bfsParent[v];
                boolean childBad = (d[v] > 2 * d[u]), parentBad = (d[u] > 2 * d[v]);
                if (childBad) {
                    if (mustIncDeepest == -1) mustIncDeepest = v;
                    else if (inTime[mustIncDeepest] <= inTime[v] && inTime[v] <= outTime[mustIncDeepest])
                        mustIncDeepest = v;
                    else if (!(inTime[v] <= inTime[mustIncDeepest] && inTime[mustIncDeepest] <= outTime[v])) {
                        impossible = true; break;
                    }
                }
                if (parentBad) mustExc.add(v);
            }

            if (impossible) { sb.append(0).append("\n"); continue; }

            int lo = 0, hi = n - 1;
            if (mustIncDeepest != -1) { lo = inTime[mustIncDeepest]; hi = outTime[mustIncDeepest]; }

            int[] diff = new int[n + 2];
            for (int v : mustExc) {
                int vlo = Math.max(inTime[v], lo), vhi = Math.min(outTime[v], hi);
                if (vlo > vhi) continue;
                diff[vlo]++; diff[vhi + 1]--;
            }

            int count = 0, cur = 0;
            for (int i = lo; i <= hi; i++) { cur += diff[i]; if (cur == 0) count++; }
            sb.append(count).append("\n");
        }
        System.out.print(sb);
    }
}

样例解析

样例 1(n=5,权值 [8,4,3,2,1],因子数 [4,3,2,2,1]):

  • 以节点 1 为根,各边方向下的约束均满足(最大约为 )。
  • 所有边双向都不超过 2 倍关系,无任何必须包含/排除约束。
  • 答案:所有 5 个节点均合法。

样例 2(n=2,权值 [1,16],因子数 [1,5]):

  • 边 (1,2):,节点 2 不能作为节点 1 的子节点。
  • 必须包含约束:根必须在 中。
  • 答案:仅节点 2 合法。