题目链接
思路分析
1. 问题建模
题目要求我们寻找一棵树中所有长度为 2 的路径,这些路径由三个节点(我们称之为 u-v-w
)组成,并满足特定条件。
- 路径结构:一条长度为 2 的路径
u-v-w
由一个中心节点v
和它的两个不同邻居u
、w
构成。 - 路径唯一性:每一条这样的路径都由其中心节点
v
和邻居对{u, w}
唯一确定。路径u-v-w
和w-v-u
被视为同一条路径。 - 条件:路径上三个节点权值的乘积
val(u) * val(v) * val(w)
的因子数量必须不小于k
。
基于此,一个直接的思路是:遍历树中的每一个节点,让它充当中心节点 v
,然后考虑其所有邻居对 {u, w}
,检查它们构成的路径是否满足条件。
2. 核心挑战:因子数量计算
计算一个数 X
的因子数量,需要知道它的标准质因数分解。如果 ,那么
X
的因子数量为 。
对于路径 u-v-w
,其权值乘积 。要计算
的因子数,我们需要
的质因数分解。
的质因数分解可以通过合并
val(u)
, val(v)
, val(w)
各自的质因数分解来得到(即对每个质因子,指数相加)。
3. 性能瓶颈与优化
一个朴素的算法是:
- 遍历每个节点
v
(从 0 到 n-1)。 - 获取
v
的邻居列表adj[v]
。 - 如果
adj[v]
的大小小于 2,则v
不能作为中心节点,跳过。 - 遍历所有邻居对
(adj[v][i], adj[v][j])
其中i < j
。 - 对于每一对,计算乘积的因子数并与
k
比较。
这个算法的瓶颈在于步骤 4。如果一个节点 v
的度数为 deg(v)
,我们需要进行 次检查。在最坏的情况下(例如星形图),一个节点的度数可能是
,这会导致
次检查,对于
来说太慢了。
优化思路:在为中心节点 v
检查其邻居对时,我们发现,如果很多邻居的权值是相同的,那么它们是等价的。
我们可以将 v
的邻居按其权值进行分组。
- 假设
v
有c1
个权值为w1
的邻居,c2
个权值为w2
的邻居,等等。 - 同权值邻居对:对于权值为
w1
的c1
个邻居,我们可以从中任选两个构成路径w1-v-w1
。共有种选法。我们只需计算一次
val(v) * w1 * w1
的因子数,如果满足条件,就将加入总答案。
- 不同权值邻居对:对于权值为
w1
的c1
个邻居和权值为w2
的c2
个邻居,我们可以构成c1 * c2
条路径w1-v-w2
。我们只需计算一次val(v) * w1 * w2
的因子数,如果满足条件,就将c1 * c2
加入总答案。
通过这种方式,对于中心节点 v
,需要检查的次数从 deg(v)^2
级别降低到 (v的邻居的不同权值数)^2
级别,这是一个巨大的优化。
4. 算法步骤
- 预计算:
a. 使用筛法预计算出
到
每个数的最小质因子 (SPF)。这可以让我们在
的时间内得到任何数的质因数分解。 b. 读入所有节点的权值,并利用 SPF 数组为每个节点的权值预先计算好其质因数分解,存起来备用。
- 建图:读入边,建立树的邻接表表示。
- 主循环:
a. 初始化总方案数
ans = 0
。 b. 遍历每个节点v
(从 0 到 n-1),将其作为路径中心。 c. 如果v
的度数小于 2,跳过。 d. 邻居分组:创建一个map<int, int>
来统计v
的邻居中每种权值的出现次数。 e. 将map
中的(权值, 次数)
对提取到一个vector
中,方便后续遍历。 f. 组合计数: i. 遍历这个vector
。对于每个(w1, c1)
,如果c1 >= 2
,检查w1-v-w1
路径,若满足条件则ans += c1 * (c1 - 1) / 2
。 ii. 使用一个嵌套循环遍历vector
中的权值对。对于(w1, c1)
和(w2, c2)
,检查w1-v-w2
路径,若满足条件则ans += c1 * c2
。 - 因子数检查函数:
a. 该函数接受三个节点的质因数分解和一个阈值
k
。 b. 它合并三个质因数分解(即将各质因子的指数相加)。 c. 然后计算总因子数。注意,因子数可能非常大,会超出long long
。因此在计算累乘时,需要使用__int128_t
(C++) 或在每次乘法前判断是否会超过k
来避免溢出。 - 输出最终的
ans
。
代码
#include <iostream>
#include <vector>
#include <numeric>
#include <map>
#include <algorithm>
using namespace std;
const int MAX_VAL = 1000001;
vector<int> spf(MAX_VAL);
vector<vector<pair<int, int>>> factorizations;
void sieve() {
iota(spf.begin(), spf.end(), 0);
for (int i = 2; i * i < MAX_VAL; ++i) {
if (spf[i] == i) { // i is prime
for (int j = i * i; j < MAX_VAL; j += i) {
if (spf[j] == j) {
spf[j] = i;
}
}
}
}
}
vector<pair<int, int>> get_factors(int n) {
vector<pair<int, int>> factors;
if (n == 1) return factors;
int current_p = spf[n];
int count = 1;
n /= spf[n];
while (n != 1) {
if (spf[n] == current_p) {
count++;
} else {
factors.push_back({current_p, count});
current_p = spf[n];
count = 1;
}
n /= spf[n];
}
factors.push_back({current_p, count});
return factors;
}
bool check(int v_val, int w1, int w2, long long k) {
map<int, int> total_factors_map;
auto factors_v = factorizations[v_val];
auto factors_1 = factorizations[w1];
auto factors_2 = factorizations[w2];
for (const auto& p : factors_v) total_factors_map[p.first] += p.second;
for (const auto& p : factors_1) total_factors_map[p.first] += p.second;
for (const auto& p : factors_2) total_factors_map[p.first] += p.second;
__int128_t num_divs = 1;
for (const auto& p : total_factors_map) {
num_divs *= (p.second + 1);
if (num_divs >= k) {
return true;
}
}
return num_divs >= k;
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
sieve();
factorizations.resize(MAX_VAL);
for (int i = 1; i < MAX_VAL; ++i) {
factorizations[i] = get_factors(i);
}
int n;
long long k;
cin >> n >> k;
vector<int> a(n + 1);
for (int i = 1; i <= n; ++i) {
cin >> a[i];
}
vector<vector<int>> adj(n + 1);
for (int i = 0; i < n - 1; ++i) {
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
long long ans = 0;
for (int v = 1; v <= n; ++v) {
if (adj[v].size() < 2) continue;
map<int, int> neighbor_counts;
for (int neighbor : adj[v]) {
neighbor_counts[a[neighbor]]++;
}
vector<pair<int, int>> unique_weights;
for (const auto& p : neighbor_counts) {
unique_weights.push_back(p);
}
for (size_t i = 0; i < unique_weights.size(); ++i) {
int w1 = unique_weights[i].first;
long long c1 = unique_weights[i].second;
if (c1 >= 2) {
if (check(a[v], w1, w1, k)) {
ans += c1 * (c1 - 1) / 2;
}
}
for (size_t j = i + 1; j < unique_weights.size(); ++j) {
int w2 = unique_weights[j].first;
long long c2 = unique_weights[j].second;
if (check(a[v], w1, w2, k)) {
ans += c1 * c2;
}
}
}
}
cout << ans << endl;
return 0;
}
import java.io.*;
import java.util.*;
public class Main {
static final int MAX_VAL = 1000001;
static int[] spf = new int[MAX_VAL];
static List<Map<Integer, Integer>> factorizations = new ArrayList<>();
static void sieve() {
for (int i = 0; i < MAX_VAL; i++) spf[i] = i;
for (int i = 2; i * i < MAX_VAL; i++) {
if (spf[i] == i) { // i is prime
for (int j = i * i; j < MAX_VAL; j += i) {
if (spf[j] == j) {
spf[j] = i;
}
}
}
}
}
static Map<Integer, Integer> getFactors(int n) {
Map<Integer, Integer> factors = new HashMap<>();
if (n == 1) return factors;
int temp = n;
while (temp != 1) {
int p = spf[temp];
factors.put(p, factors.getOrDefault(p, 0) + 1);
temp /= p;
}
return factors;
}
static boolean check(int vVal, int w1, int w2, long k) {
Map<Integer, Integer> totalFactorsMap = new HashMap<>();
for (Map.Entry<Integer, Integer> entry : factorizations.get(vVal).entrySet()) {
totalFactorsMap.put(entry.getKey(), totalFactorsMap.getOrDefault(entry.getKey(), 0) + entry.getValue());
}
for (Map.Entry<Integer, Integer> entry : factorizations.get(w1).entrySet()) {
totalFactorsMap.put(entry.getKey(), totalFactorsMap.getOrDefault(entry.getKey(), 0) + entry.getValue());
}
for (Map.Entry<Integer, Integer> entry : factorizations.get(w2).entrySet()) {
totalFactorsMap.put(entry.getKey(), totalFactorsMap.getOrDefault(entry.getKey(), 0) + entry.getValue());
}
long numDivs = 1;
for (int exp : totalFactorsMap.values()) {
// Check for overflow before multiplication
if ((double) numDivs * (exp + 1) >= k) {
return true;
}
numDivs *= (exp + 1);
}
return numDivs >= k;
}
public static void main(String[] args) throws IOException {
FastReader sc = new FastReader();
sieve();
for (int i = 0; i < MAX_VAL; i++) {
factorizations.add(getFactors(i));
}
int n = sc.nextInt();
long k = sc.nextLong();
int[] a = new int[n + 1];
for (int i = 1; i <= n; i++) {
a[i] = sc.nextInt();
}
List<List<Integer>> adj = new ArrayList<>();
for (int i = 0; i <= n; i++) {
adj.add(new ArrayList<>());
}
for (int i = 0; i < n - 1; i++) {
int u = sc.nextInt();
int v = sc.nextInt();
adj.get(u).add(v);
adj.get(v).add(u);
}
long ans = 0;
for (int v = 1; v <= n; v++) {
if (adj.get(v).size() < 2) continue;
Map<Integer, Integer> neighborCounts = new HashMap<>();
for (int neighbor : adj.get(v)) {
neighborCounts.put(a[neighbor], neighborCounts.getOrDefault(a[neighbor], 0) + 1);
}
List<Map.Entry<Integer, Integer>> uniqueWeights = new ArrayList<>(neighborCounts.entrySet());
for (int i = 0; i < uniqueWeights.size(); i++) {
int w1 = uniqueWeights.get(i).getKey();
long c1 = uniqueWeights.get(i).getValue();
if (c1 >= 2) {
if (check(a[v], w1, w1, k)) {
ans += c1 * (c1 - 1) / 2;
}
}
for (int j = i + 1; j < uniqueWeights.size(); j++) {
int w2 = uniqueWeights.get(j).getKey();
long c2 = uniqueWeights.get(j).getValue();
if (check(a[v], w1, w2, k)) {
ans += c1 * c2;
}
}
}
}
System.out.println(ans);
}
static class FastReader {
BufferedReader br; StringTokenizer st;
public FastReader(){br = new BufferedReader(new InputStreamReader(System.in));}
String next(){while(st==null||!st.hasMoreElements()){try{st=new StringTokenizer(br.readLine());}catch(IOException e){e.printStackTrace();}}return st.nextToken();}
int nextInt(){return Integer.parseInt(next());}
long nextLong(){return Long.parseLong(next());}
}
}
import sys
from collections import defaultdict
MAX_VAL = 1000001
spf = list(range(MAX_VAL))
factorizations = [None] * MAX_VAL
def sieve():
for i in range(2, int(MAX_VAL**0.5) + 1):
if spf[i] == i: # i is prime
for j in range(i * i, MAX_VAL, i):
if spf[j] == j:
spf[j] = i
def get_factors(n):
factors = defaultdict(int)
if n == 1:
return {}
temp = n
while temp != 1:
p = spf[temp]
factors[p] += 1
temp //= p
return dict(factors)
def check(v_val, w1, w2, k):
total_factors_map = defaultdict(int)
for p, e in factorizations[v_val].items():
total_factors_map[p] += e
for p, e in factorizations[w1].items():
total_factors_map[p] += e
for p, e in factorizations[w2].items():
total_factors_map[p] += e
num_divs = 1
for exp in total_factors_map.values():
num_divs *= (exp + 1)
if num_divs >= k:
return True
return num_divs >= k
def main():
# It's recommended to use PyPy for this kind of problem in Python
sieve()
for i in range(1, MAX_VAL):
factorizations[i] = get_factors(i)
try:
n, k = map(int, sys.stdin.readline().split())
a = [0] + list(map(int, sys.stdin.readline().split()))
adj = [[] for _ in range(n + 1)]
for _ in range(n - 1):
u, v = map(int, sys.stdin.readline().split())
adj[u].append(v)
adj[v].append(u)
except (IOError, ValueError):
return
ans = 0
for v in range(1, n + 1):
if len(adj[v]) < 2:
continue
neighbor_counts = defaultdict(int)
for neighbor in adj[v]:
neighbor_counts[a[neighbor]] += 1
unique_weights = list(neighbor_counts.items())
for i in range(len(unique_weights)):
w1, c1 = unique_weights[i]
if c1 >= 2:
if check(a[v], w1, w1, k):
ans += c1 * (c1 - 1) // 2
for j in range(i + 1, len(unique_weights)):
w2, c2 = unique_weights[j]
if check(a[v], w1, w2, k):
ans += c1 * c2
print(ans)
if __name__ == "__main__":
main()
算法及复杂度
- 算法:筛法预处理质数、图遍历、组合计数
- 时间复杂度:
,其中
是最大权值 (
),
是节点
的邻居中不同权值的数量。
- 筛法预处理质数部分为
。
- 为所有权值预计算质因数分解约为
。
- 主循环中,对每个中心节点
v
,最耗时的部分是遍历不同权值邻居的对,其复杂度与v
的邻居的不同权值数的平方成正比。在大多数情况下,这个值远小于节点的度数,使得该算法能够通过。
- 筛法预处理质数部分为
- 空间复杂度:
。主要开销在于存储
到
所有数的质因数分解和树的邻接表。