题目链接

小红的树上路径

思路分析

1. 问题建模

题目要求我们寻找一棵树中所有长度为 2 的路径,这些路径由三个节点(我们称之为 u-v-w)组成,并满足特定条件。

  • 路径结构:一条长度为 2 的路径 u-v-w 由一个中心节点 v 和它的两个不同邻居 uw 构成。
  • 路径唯一性:每一条这样的路径都由其中心节点 v 和邻居对 {u, w} 唯一确定。路径 u-v-ww-v-u 被视为同一条路径。
  • 条件:路径上三个节点权值的乘积 val(u) * val(v) * val(w) 的因子数量必须不小于 k

基于此,一个直接的思路是:遍历树中的每一个节点,让它充当中心节点 v,然后考虑其所有邻居对 {u, w},检查它们构成的路径是否满足条件。

2. 核心挑战:因子数量计算

计算一个数 X 的因子数量,需要知道它的标准质因数分解。如果 ,那么 X 的因子数量为

对于路径 u-v-w,其权值乘积 。要计算 的因子数,我们需要 的质因数分解。 的质因数分解可以通过合并 val(u), val(v), val(w) 各自的质因数分解来得到(即对每个质因子,指数相加)。

3. 性能瓶颈与优化

一个朴素的算法是:

  1. 遍历每个节点 v (从 0 到 n-1)。
  2. 获取 v 的邻居列表 adj[v]
  3. 如果 adj[v] 的大小小于 2,则 v 不能作为中心节点,跳过。
  4. 遍历所有邻居对 (adj[v][i], adj[v][j]) 其中 i < j
  5. 对于每一对,计算乘积的因子数并与 k 比较。

这个算法的瓶颈在于步骤 4。如果一个节点 v 的度数为 deg(v),我们需要进行 次检查。在最坏的情况下(例如星形图),一个节点的度数可能是 ,这会导致 次检查,对于 来说太慢了。

优化思路:在为中心节点 v 检查其邻居对时,我们发现,如果很多邻居的权值是相同的,那么它们是等价的。 我们可以将 v 的邻居按其权值进行分组。

  • 假设 vc1 个权值为 w1 的邻居,c2 个权值为 w2 的邻居,等等。
  • 同权值邻居对:对于权值为 w1c1 个邻居,我们可以从中任选两个构成路径 w1-v-w1。共有 种选法。我们只需计算一次 val(v) * w1 * w1 的因子数,如果满足条件,就将 加入总答案。
  • 不同权值邻居对:对于权值为 w1c1 个邻居和权值为 w2c2 个邻居,我们可以构成 c1 * c2 条路径 w1-v-w2。我们只需计算一次 val(v) * w1 * w2 的因子数,如果满足条件,就将 c1 * c2 加入总答案。

通过这种方式,对于中心节点 v,需要检查的次数从 deg(v)^2 级别降低到 (v的邻居的不同权值数)^2 级别,这是一个巨大的优化。

4. 算法步骤

  1. 预计算: a. 使用筛法预计算出 每个数的最小质因子 (SPF)。这可以让我们在 的时间内得到任何数的质因数分解。 b. 读入所有节点的权值,并利用 SPF 数组为每个节点的权值预先计算好其质因数分解,存起来备用。
  2. 建图:读入边,建立树的邻接表表示。
  3. 主循环: a. 初始化总方案数 ans = 0。 b. 遍历每个节点 v (从 0 到 n-1),将其作为路径中心。 c. 如果 v 的度数小于 2,跳过。 d. 邻居分组:创建一个 map<int, int> 来统计 v 的邻居中每种权值的出现次数。 e. 将 map 中的 (权值, 次数) 对提取到一个 vector 中,方便后续遍历。 f. 组合计数: i. 遍历这个 vector。对于每个 (w1, c1),如果 c1 >= 2,检查 w1-v-w1 路径,若满足条件则 ans += c1 * (c1 - 1) / 2。 ii. 使用一个嵌套循环遍历 vector 中的权值对。对于 (w1, c1)(w2, c2),检查 w1-v-w2 路径,若满足条件则 ans += c1 * c2
  4. 因子数检查函数: a. 该函数接受三个节点的质因数分解和一个阈值 k。 b. 它合并三个质因数分解(即将各质因子的指数相加)。 c. 然后计算总因子数。注意,因子数可能非常大,会超出 long long。因此在计算累乘时,需要使用 __int128_t (C++) 或在每次乘法前判断是否会超过 k 来避免溢出。
  5. 输出最终的 ans

代码

#include <iostream>
#include <vector>
#include <numeric>
#include <map>
#include <algorithm>

using namespace std;

const int MAX_VAL = 1000001;
vector<int> spf(MAX_VAL);
vector<vector<pair<int, int>>> factorizations;

void sieve() {
    iota(spf.begin(), spf.end(), 0);
    for (int i = 2; i * i < MAX_VAL; ++i) {
        if (spf[i] == i) { // i is prime
            for (int j = i * i; j < MAX_VAL; j += i) {
                if (spf[j] == j) {
                    spf[j] = i;
                }
            }
        }
    }
}

vector<pair<int, int>> get_factors(int n) {
    vector<pair<int, int>> factors;
    if (n == 1) return factors;
    int current_p = spf[n];
    int count = 1;
    n /= spf[n];
    while (n != 1) {
        if (spf[n] == current_p) {
            count++;
        } else {
            factors.push_back({current_p, count});
            current_p = spf[n];
            count = 1;
        }
        n /= spf[n];
    }
    factors.push_back({current_p, count});
    return factors;
}

bool check(int v_val, int w1, int w2, long long k) {
    map<int, int> total_factors_map;
    
    auto factors_v = factorizations[v_val];
    auto factors_1 = factorizations[w1];
    auto factors_2 = factorizations[w2];

    for (const auto& p : factors_v) total_factors_map[p.first] += p.second;
    for (const auto& p : factors_1) total_factors_map[p.first] += p.second;
    for (const auto& p : factors_2) total_factors_map[p.first] += p.second;
    
    __int128_t num_divs = 1;
    for (const auto& p : total_factors_map) {
        num_divs *= (p.second + 1);
        if (num_divs >= k) {
            return true;
        }
    }
    return num_divs >= k;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    sieve();
    factorizations.resize(MAX_VAL);
    for (int i = 1; i < MAX_VAL; ++i) {
        factorizations[i] = get_factors(i);
    }

    int n;
    long long k;
    cin >> n >> k;
    vector<int> a(n + 1);
    for (int i = 1; i <= n; ++i) {
        cin >> a[i];
    }

    vector<vector<int>> adj(n + 1);
    for (int i = 0; i < n - 1; ++i) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    long long ans = 0;

    for (int v = 1; v <= n; ++v) {
        if (adj[v].size() < 2) continue;

        map<int, int> neighbor_counts;
        for (int neighbor : adj[v]) {
            neighbor_counts[a[neighbor]]++;
        }

        vector<pair<int, int>> unique_weights;
        for (const auto& p : neighbor_counts) {
            unique_weights.push_back(p);
        }

        for (size_t i = 0; i < unique_weights.size(); ++i) {
            int w1 = unique_weights[i].first;
            long long c1 = unique_weights[i].second;

            if (c1 >= 2) {
                if (check(a[v], w1, w1, k)) {
                    ans += c1 * (c1 - 1) / 2;
                }
            }

            for (size_t j = i + 1; j < unique_weights.size(); ++j) {
                int w2 = unique_weights[j].first;
                long long c2 = unique_weights[j].second;
                if (check(a[v], w1, w2, k)) {
                    ans += c1 * c2;
                }
            }
        }
    }

    cout << ans << endl;

    return 0;
}
import java.io.*;
import java.util.*;

public class Main {
    static final int MAX_VAL = 1000001;
    static int[] spf = new int[MAX_VAL];
    static List<Map<Integer, Integer>> factorizations = new ArrayList<>();

    static void sieve() {
        for (int i = 0; i < MAX_VAL; i++) spf[i] = i;
        for (int i = 2; i * i < MAX_VAL; i++) {
            if (spf[i] == i) { // i is prime
                for (int j = i * i; j < MAX_VAL; j += i) {
                    if (spf[j] == j) {
                        spf[j] = i;
                    }
                }
            }
        }
    }

    static Map<Integer, Integer> getFactors(int n) {
        Map<Integer, Integer> factors = new HashMap<>();
        if (n == 1) return factors;
        int temp = n;
        while (temp != 1) {
            int p = spf[temp];
            factors.put(p, factors.getOrDefault(p, 0) + 1);
            temp /= p;
        }
        return factors;
    }

    static boolean check(int vVal, int w1, int w2, long k) {
        Map<Integer, Integer> totalFactorsMap = new HashMap<>();
        
        for (Map.Entry<Integer, Integer> entry : factorizations.get(vVal).entrySet()) {
            totalFactorsMap.put(entry.getKey(), totalFactorsMap.getOrDefault(entry.getKey(), 0) + entry.getValue());
        }
        for (Map.Entry<Integer, Integer> entry : factorizations.get(w1).entrySet()) {
            totalFactorsMap.put(entry.getKey(), totalFactorsMap.getOrDefault(entry.getKey(), 0) + entry.getValue());
        }
        for (Map.Entry<Integer, Integer> entry : factorizations.get(w2).entrySet()) {
            totalFactorsMap.put(entry.getKey(), totalFactorsMap.getOrDefault(entry.getKey(), 0) + entry.getValue());
        }

        long numDivs = 1;
        for (int exp : totalFactorsMap.values()) {
            // Check for overflow before multiplication
            if ((double) numDivs * (exp + 1) >= k) {
                return true;
            }
            numDivs *= (exp + 1);
        }
        return numDivs >= k;
    }

    public static void main(String[] args) throws IOException {
        FastReader sc = new FastReader();
        sieve();
        for (int i = 0; i < MAX_VAL; i++) {
            factorizations.add(getFactors(i));
        }

        int n = sc.nextInt();
        long k = sc.nextLong();
        int[] a = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            a[i] = sc.nextInt();
        }

        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i <= n; i++) {
            adj.add(new ArrayList<>());
        }
        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt();
            int v = sc.nextInt();
            adj.get(u).add(v);
            adj.get(v).add(u);
        }

        long ans = 0;

        for (int v = 1; v <= n; v++) {
            if (adj.get(v).size() < 2) continue;

            Map<Integer, Integer> neighborCounts = new HashMap<>();
            for (int neighbor : adj.get(v)) {
                neighborCounts.put(a[neighbor], neighborCounts.getOrDefault(a[neighbor], 0) + 1);
            }

            List<Map.Entry<Integer, Integer>> uniqueWeights = new ArrayList<>(neighborCounts.entrySet());

            for (int i = 0; i < uniqueWeights.size(); i++) {
                int w1 = uniqueWeights.get(i).getKey();
                long c1 = uniqueWeights.get(i).getValue();

                if (c1 >= 2) {
                    if (check(a[v], w1, w1, k)) {
                        ans += c1 * (c1 - 1) / 2;
                    }
                }

                for (int j = i + 1; j < uniqueWeights.size(); j++) {
                    int w2 = uniqueWeights.get(j).getKey();
                    long c2 = uniqueWeights.get(j).getValue();
                    if (check(a[v], w1, w2, k)) {
                        ans += c1 * c2;
                    }
                }
            }
        }
        System.out.println(ans);
    }
    
    static class FastReader {
        BufferedReader br; StringTokenizer st;
        public FastReader(){br = new BufferedReader(new InputStreamReader(System.in));}
        String next(){while(st==null||!st.hasMoreElements()){try{st=new StringTokenizer(br.readLine());}catch(IOException e){e.printStackTrace();}}return st.nextToken();}
        int nextInt(){return Integer.parseInt(next());}
        long nextLong(){return Long.parseLong(next());}
    }
}
import sys
from collections import defaultdict

MAX_VAL = 1000001
spf = list(range(MAX_VAL))
factorizations = [None] * MAX_VAL

def sieve():
    for i in range(2, int(MAX_VAL**0.5) + 1):
        if spf[i] == i: # i is prime
            for j in range(i * i, MAX_VAL, i):
                if spf[j] == j:
                    spf[j] = i

def get_factors(n):
    factors = defaultdict(int)
    if n == 1:
        return {}
    temp = n
    while temp != 1:
        p = spf[temp]
        factors[p] += 1
        temp //= p
    return dict(factors)

def check(v_val, w1, w2, k):
    total_factors_map = defaultdict(int)
    
    for p, e in factorizations[v_val].items():
        total_factors_map[p] += e
    for p, e in factorizations[w1].items():
        total_factors_map[p] += e
    for p, e in factorizations[w2].items():
        total_factors_map[p] += e
        
    num_divs = 1
    for exp in total_factors_map.values():
        num_divs *= (exp + 1)
        if num_divs >= k:
            return True
    return num_divs >= k

def main():
    # It's recommended to use PyPy for this kind of problem in Python
    sieve()
    for i in range(1, MAX_VAL):
        factorizations[i] = get_factors(i)

    try:
        n, k = map(int, sys.stdin.readline().split())
        a = [0] + list(map(int, sys.stdin.readline().split()))
        
        adj = [[] for _ in range(n + 1)]
        for _ in range(n - 1):
            u, v = map(int, sys.stdin.readline().split())
            adj[u].append(v)
            adj[v].append(u)
    except (IOError, ValueError):
        return

    ans = 0
    for v in range(1, n + 1):
        if len(adj[v]) < 2:
            continue

        neighbor_counts = defaultdict(int)
        for neighbor in adj[v]:
            neighbor_counts[a[neighbor]] += 1
        
        unique_weights = list(neighbor_counts.items())

        for i in range(len(unique_weights)):
            w1, c1 = unique_weights[i]
            if c1 >= 2:
                if check(a[v], w1, w1, k):
                    ans += c1 * (c1 - 1) // 2
            
            for j in range(i + 1, len(unique_weights)):
                w2, c2 = unique_weights[j]
                if check(a[v], w1, w2, k):
                    ans += c1 * c2
    
    print(ans)

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:筛法预处理质数、图遍历、组合计数
  • 时间复杂度,其中 是最大权值 (), 是节点 的邻居中不同权值的数量。
    • 筛法预处理质数部分为
    • 为所有权值预计算质因数分解约为
    • 主循环中,对每个中心节点 v,最耗时的部分是遍历不同权值邻居的对,其复杂度与 v 的邻居的不同权值数的平方成正比。在大多数情况下,这个值远小于节点的度数,使得该算法能够通过。
  • 空间复杂度。主要开销在于存储 所有数的质因数分解和树的邻接表。