小红的树上路径

[题目链接](https://www.nowcoder.com/practice/bdb252f393064fb08b93b111ead1a3f8)

思路

给定一棵 个节点的树,每个节点有权值 )。要求选择一条长度为 2 的路径(经过 3 个节点 --,其中 是中间节点),使得 的因子个数不小于 ,求方案数。

关键观察

长度为 2 的路径 -- 等价于:选择一个中间节点 ,再从 的邻居中选两个不同的节点 。因此我们枚举每个节点 作为路径中心。

因子个数公式

对于正整数 ,若其质因数分解为 ,则因子个数为:

$$

三个数的乘积 的因子个数,等于将三者的质因数分解合并后,对每个质数的指数之和加 1 再相乘。

按值分组优化

暴力枚举所有邻居对的复杂度为 ,在星形树上退化为 ,无法通过。

注意到 ,节点的权值最多只有 100 种不同取值。对于每个中间节点 ,将其邻居按权值分组,记 为权值等于 的邻居个数。

枚举所有值对 ,计算 的因子个数是否

  • ,贡献
  • ,贡献

不同取值最多 100 种,所以每个节点的值对枚举量为 ,总复杂度为 ,其中 是 100 以内的质数个数。

样例演示

输入的树结构为 ---,权值分别为

以节点 2 为中心:邻居为 ,乘积 ,因子数

以节点 3 为中心:邻居为 ,乘积 ,因子数 。符合条件。

答案为

复杂度分析

  • 时间复杂度:,其中 为权值上界, 为 100 以内质数个数。
  • 空间复杂度:

代码

#include <bits/stdc++.h>
using namespace std;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    long long k;
    cin >> n >> k;

    int primes[] = {2,3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67,71,73,79,83,89,97};
    int P = 25;

    // 预处理 1~100 每个值的质因数分解
    int valFac[101][25] = {};
    for (int v = 1; v <= 100; v++) {
        int x = v;
        for (int i = 0; i < P; i++) {
            while (x % primes[i] == 0) {
                valFac[v][i]++;
                x /= primes[i];
            }
        }
    }

    vector<int> a(n + 1);
    for (int i = 1; i <= n; i++) cin >> a[i];

    vector<vector<int>> adj(n + 1);
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    long long ans = 0;

    for (int v = 1; v <= n; v++) {
        if ((int)adj[v].size() < 2) continue;

        // 按权值分组
        map<int, long long> cnt;
        for (int u : adj[v]) cnt[a[u]]++;

        vector<pair<int, long long>> groups(cnt.begin(), cnt.end());
        int G = groups.size();

        for (int i = 0; i < G; i++) {
            int val1 = groups[i].first;
            long long c1 = groups[i].second;
            // 合并 val1 和 a[v] 的质因数分解
            int merged[25];
            for (int p = 0; p < P; p++)
                merged[p] = valFac[val1][p] + valFac[a[v]][p];

            for (int j = i; j < G; j++) {
                int val2 = groups[j].first;
                long long c2 = groups[j].second;
                // 计算因子个数
                long long dc = 1;
                for (int p = 0; p < P; p++)
                    dc *= (merged[p] + valFac[val2][p] + 1);

                if (dc >= k) {
                    if (i == j) ans += c1 * (c1 - 1) / 2;
                    else ans += c1 * c2;
                }
            }
        }
    }

    cout << ans << endl;
    return 0;
}
import java.util.*;
import java.io.*;

public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        int n = Integer.parseInt(st.nextToken());
        long k = Long.parseLong(st.nextToken());

        int[] primes = {2,3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67,71,73,79,83,89,97};
        int P = 25;

        // 预处理 1~100 每个值的质因数分解
        int[][] valFac = new int[101][P];
        for (int v = 1; v <= 100; v++) {
            int x = v;
            for (int i = 0; i < P; i++) {
                while (x % primes[i] == 0) {
                    valFac[v][i]++;
                    x /= primes[i];
                }
            }
        }

        int[] a = new int[n + 1];
        st = new StringTokenizer(br.readLine());
        for (int i = 1; i <= n; i++) a[i] = Integer.parseInt(st.nextToken());

        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i <= n; i++) adj.add(new ArrayList<>());
        for (int i = 0; i < n - 1; i++) {
            st = new StringTokenizer(br.readLine());
            int u = Integer.parseInt(st.nextToken());
            int v = Integer.parseInt(st.nextToken());
            adj.get(u).add(v);
            adj.get(v).add(u);
        }

        long ans = 0;

        for (int v = 1; v <= n; v++) {
            List<Integer> nbrs = adj.get(v);
            int deg = nbrs.size();
            if (deg < 2) continue;

            // 按权值分组
            Map<Integer, Long> cnt = new HashMap<>();
            for (int u : nbrs) cnt.merge(a[u], 1L, Long::sum);

            List<Map.Entry<Integer, Long>> groups = new ArrayList<>(cnt.entrySet());
            int G = groups.size();
            int[] facV = valFac[a[v]];

            for (int i = 0; i < G; i++) {
                int val1 = groups.get(i).getKey();
                long c1 = groups.get(i).getValue();
                int[] f1 = valFac[val1];
                int[] merged = new int[P];
                for (int p = 0; p < P; p++)
                    merged[p] = f1[p] + facV[p];

                for (int j = i; j < G; j++) {
                    int val2 = groups.get(j).getKey();
                    long c2 = groups.get(j).getValue();
                    int[] f2 = valFac[val2];
                    long dc = 1;
                    for (int p = 0; p < P; p++)
                        dc *= (merged[p] + f2[p] + 1);

                    if (dc >= k) {
                        if (i == j) ans += c1 * (c1 - 1) / 2;
                        else ans += c1 * c2;
                    }
                }
            }
        }

        System.out.println(ans);
    }
}