Weak PairTime Limit: 4000/2000 MS (Java/Others) Memory Limit: 262144/262144 K (Java/Others)Total Submission(s): 5327 Accepted Submission(s): 1543 Problem Description You are given a rooted tree of N nodes, labeled from 1 to N . To the i th node a non-negative value ai is assigned.An ordered pair of nodes (u,v) is said to be weak if
Input There are multiple cases in the data set.
Output For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.
Sample Input 1 2 3 1 2 1 2
Sample Output 1 |
大体题意就是找到多找个(u,v)得组合使value[u]*value[v]<=k;其中u是v得祖先;
首先我们可以想到,求这种组合方式我们可以通过dfs,查询到某点时,看看对于这个v点而言能让组合成立的(u,v)到底有多少种。这样的话,我们可以想到用树状数组,查询价值小于一个数(k/value[v])的祖先点到底有几个。
这样的话又出现了一个问题,就是数据范围太大了,于是我们可以先把每个点v的value以及能使他们相乘<=k的值need[v]预处理出来,然后进行离散化就可以了。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
using namespace std;
const int maxn = 200005;
const int inf = 0x3f3f3f3f;
#define ll long long int
int n, in[maxn], c[maxn];
ll k, ans, v[maxn], need[maxn], sot[maxn];
vector<int>G[maxn];
void init() {
for (int s = 1; s <= n; s++)
G[s].clear();
memset(c, 0, sizeof(c));
memset(in, 0, sizeof(in));
ans = 0;
}
int lowbit(int x) {
return x&(-x);
}
int sum(int x) {
int ans = 0;
while (x) {
ans += c[x];
x -= lowbit(x);
}
return ans;
}
void add(int x, int d) {
while (x <= 2 * n) {
c[x] += d;
x += lowbit(x);
}
}
void dfs(int x) {
ans += sum(need[x]);
add(v[x], 1);
int sz = G[x].size();
for (int s = 0; s < sz; s++)
dfs(G[x][s]);
add(v[x], -1);
}
int main() {
int te;
scanf("%d", &te);
while (te--) {
scanf("%d%lld", &n, &k);
init();
for (int s = 1; s <= n; s++) {
scanf("%lld", &v[s]);
if (v[s] == 0)
need[s] = 1e18;
else
need[s] = k / v[s];
sot[s - 1] = v[s];
sot[n + s - 2] = need[s];
}
sort(sot, sot + 2 * n);
for (int s = 1; s <= n; s++) {
v[s] = lower_bound(sot, sot + 2 * n, v[s]) - sot + 1;
need[s]= lower_bound(sot, sot + 2 * n, need[s]) - sot + 1;
}
for (int s = 1; s < n; s++) {
int a, b;
scanf("%d%d", &a, &b);
G[a].push_back(b);
in[b] = 1;
}
for (int s = 1; s <= n; s++) {
if (!in[s]) {
dfs(s);
break;
}
}
printf("%lld\n", ans);
}
}