Weak Pair

Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 262144/262144 K (Java/Others)
Total Submission(s): 5327    Accepted Submission(s): 1543

Problem Description

You are given a rooted tree of N nodes, labeled from 1 to N. To the i-th node a non-negative value a_i is assigned. An ordered pair of nodes (u, v) is said to be weak if
  (1) u is an ancestor of v (Note: in this problem a node u is not considered an ancestor of itself);
  (2) a_u × a_v ≤ k.

Can you find the number of weak pairs in the tree?
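
To make the definition concrete, here is a brute-force counter that checks condition (2) for every proper ancestor of every node. The function name countWeakPairsBrute and the parent-array input (1-indexed nodes, parent of the root stored as 0) are hypothetical choices for this sketch, not the problem's input format. It runs in O(N × depth), far too slow for N = 10^5 on a deep tree, but useful as a reference to test a faster solution against.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Brute force straight from the definition. Nodes are 1-indexed; parent[v]
// is the parent of v, parent[root] == 0, and index 0 of both arrays is unused.
long long countWeakPairsBrute(const std::vector<int>& parent,
                              const std::vector<long long>& a, long long k) {
    assert(parent.size() == a.size());
    long long pairs = 0;
    for (std::size_t v = 1; v < a.size(); ++v) {
        // Condition (1): walk the proper ancestors of v, starting at its parent.
        for (int u = parent[v]; u != 0; u = parent[u]) {
            // Condition (2). With a_i <= 10^9 the product is at most 10^18,
            // which fits in a signed 64-bit long long.
            if (a[u] * a[v] <= k) ++pairs;
        }
    }
    return pairs;
}
```

For the sample tree (N = 2, k = 3, a = 1 2, node 1 is the parent of node 2), countWeakPairsBrute({0, 0, 1}, {0, 1, 2}, 3) returns 1, matching the expected output.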

Input

There are multiple cases in the data set.
  The first line of input contains an integer T denoting the number of test cases.
  For each case, the first line contains two space-separated integers, N and k, respectively.
  The second line contains N space-separated integers, denoting a_1 to a_N.
  Each of the subsequent N - 1 lines contains two space-separated integers defining an edge connecting nodes u and v, where node u is the parent of node v.

  Constraints:
  
  1 ≤ N ≤ 10^5
  
  0 ≤ a_i ≤ 10^9
  
  0 ≤ k ≤ 10^18

Output

For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.

Sample Input

1
2 3
1 2
1 2

Sample Output

1

Roughly, the problem asks for the number of pairs (u, v) with value[u] * value[v] <= k, where u is an ancestor of v.

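A first idea is to count the pairs during a DFS: when the DFS reaches a node v, count how many pairs (u, v) are valid for this particular v. The ancestors of v are exactly the nodes on the current root-to-v path, so if a Binary Indexed Tree (Fenwick tree) holds the values of the nodes on that path (insert a node's value when the DFS enters it, remove it when the DFS leaves), a single prefix query tells us how many ancestors have a value of at most ⌊k / value[v]⌋. For value[v] > 0 this is equivalent to value[u] * value[v] <= k.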
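
The path-only Fenwick tree idea can be sketched on the simplest tree, a single chain, where the DFS never has to remove anything. This is a minimal sketch assuming all values are integers in 1..maxV, so they can index the tree directly without compression; the names Fenwick and countWeakPairsOnChain are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Minimal Fenwick tree (binary indexed tree) of counts over positions 1..n.
struct Fenwick {
    std::vector<int> c;
    explicit Fenwick(int n) : c(n + 1, 0) {}
    // Add d to the count at position x.
    void add(int x, int d) {
        for (; x < static_cast<int>(c.size()); x += x & -x) c[x] += d;
    }
    // Number of stored items at positions 1..x.
    int prefix(int x) const {
        int s = 0;
        for (; x > 0; x -= x & -x) s += c[x];
        return s;
    }
};

// Walk down a chain root -> leaf. Before a node is inserted, the tree holds
// exactly the values of its ancestors, so one prefix query counts the
// ancestors u with a_u <= floor(k / a_v), i.e. a_u * a_v <= k.
// Assumes every value lies in 1..maxV (no compression needed).
long long countWeakPairsOnChain(const std::vector<int>& values, long long k, int maxV) {
    Fenwick bit(maxV);
    long long pairs = 0;
    for (int v : values) {
        assert(1 <= v && v <= maxV);
        long long limit = k / v;            // largest ancestor value allowed
        if (limit > maxV) limit = maxV;     // clamp to the indexed range
        pairs += bit.prefix(static_cast<int>(limit));
        bit.add(v, 1);                      // v is an ancestor of every later node
    }
    return pairs;
}
```

For the chain with values 2, 3, 4 and k = 8, countWeakPairsOnChain({2, 3, 4}, 8, 4) returns 2: the value pairs (2, 3) and (2, 4) have products 6 and 8, while 3 × 4 = 12 exceeds k.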

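This raises a second problem: the numbers are far too large to index a Binary Indexed Tree directly (a_i goes up to 10^9 and k / a_i up to 10^18). So we first precompute, for every node v, its value and need[v], the largest value whose product with value[v] is still <= k (need[v] = ⌊k / value[v]⌋, or 10^18 when value[v] = 0, since then every ancestor qualifies), and then discretize (coordinate-compress) all 2N numbers together. Because values and needs share one sorted array, comparing their ranks gives the same result as comparing the original numbers.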
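
A minimal sketch of that compression step, assuming the values arrive as a plain 0-indexed array (unlike the 1-indexed code below). The helper compress and the struct Compressed are hypothetical names. Unlike the original code it removes duplicates with std::unique; with or without that step, std::lower_bound maps equal numbers to the same rank, so the comparisons come out the same.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical result type: 1-based ranks of a_v and need[v] in one shared
// sorted array of distinct numbers.
struct Compressed {
    std::vector<int> valueRank;
    std::vector<int> needRank;
};

// need[v] = floor(k / a_v), or 10^18 when a_v == 0 (every ancestor qualifies).
// After compression, a_u <= need[v] holds exactly when
// valueRank[u] <= needRank[v], because both come from the same sorted array.
Compressed compress(const std::vector<long long>& a, long long k) {
    const long long kAll = 1000000000000000000LL;  // 10^18, as in the original code
    std::vector<long long> need(a.size()), all;
    for (std::size_t i = 0; i < a.size(); ++i) {
        assert(a[i] >= 0);
        need[i] = (a[i] == 0) ? kAll : k / a[i];
        all.push_back(a[i]);
        all.push_back(need[i]);
    }
    std::sort(all.begin(), all.end());
    all.erase(std::unique(all.begin(), all.end()), all.end());
    auto rank = [&](long long x) {
        return static_cast<int>(std::lower_bound(all.begin(), all.end(), x) - all.begin()) + 1;
    };
    Compressed out;
    for (std::size_t i = 0; i < a.size(); ++i) {
        out.valueRank.push_back(rank(a[i]));
        out.needRank.push_back(rank(need[i]));
    }
    return out;
}
```

For a = {1, 2} and k = 3, need = {3, 1}, the sorted distinct numbers are 1, 2, 3, and compress returns valueRank = {1, 2} and needRank = {3, 1}.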

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
using namespace std;
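const int maxn = 200005;	// room for all 2 * N raw numbers / compressed ranks
const int inf = 0x3f3f3f3f;
#define ll long long int 
int n, in[maxn], c[maxn];	// in[v] = 1 if v has a parent; c[] is the BIT
ll k, ans, v[maxn], need[maxn], sot[maxn];	// v[]/need[] become ranks; sot[] holds the 2N raw numbers
vector<int>G[maxn];	// G[u] = children of u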
void init() {
	for (int s = 1; s <= n; s++)
		G[s].clear();
	memset(c, 0, sizeof(c));
	memset(in, 0, sizeof(in));
	ans = 0;
}
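// Binary Indexed Tree over compressed ranks 1..2n
int lowbit(int x) {
	return x&(-x);
}
// number of values on the current root-to-node path with rank <= x
int sum(int x) {
	int ans = 0;
	while (x) {
		ans += c[x];
		x -= lowbit(x);
	}
	return ans;
}
// add d to the count stored at rank x
void add(int x, int d) {
	while (x <= 2 * n) {
		c[x] += d;
		x += lowbit(x);
	}
}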
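// On entering x, the BIT holds exactly the values of x's ancestors
void dfs(int x) {
	ans += sum(need[x]);	// ancestors u with a_u <= need[x], i.e. a_u * a_x <= k
	add(v[x], 1);	// x becomes an ancestor of its whole subtree
	int sz = G[x].size();
	for (int s = 0; s < sz; s++)
		dfs(G[x][s]);
	add(v[x], -1);	// leaving x's subtree: remove x from the path
}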
int main() {
	int te;
	scanf("%d", &te);
	while (te--) {
		scanf("%d%lld", &n, &k);
		init();
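		for (int s = 1; s <= n; s++) {
			scanf("%lld", &v[s]);
			if (v[s] == 0)
				need[s] = 1e18;	// a_s = 0: every ancestor forms a weak pair
			else 
				need[s] = k / v[s];	// largest a_u with a_u * a_s <= k
			sot[s - 1] = v[s];	// slots 0 .. n-1
			sot[n + s - 1] = need[s];	// slots n .. 2n-1
		}
		sort(sot, sot + 2 * n);
		// replace each number by its 1-based rank; equal numbers share a rank
		for (int s = 1; s <= n; s++) {
			v[s] = lower_bound(sot, sot + 2 * n, v[s]) - sot + 1;
			need[s]= lower_bound(sot, sot + 2 * n, need[s]) - sot + 1;
		}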
		for (int s = 1; s < n; s++) {
			int a, b;
			scanf("%d%d", &a, &b);
			G[a].push_back(b);
			in[b] = 1;
		}
		// the root is the only node that never appears as a child
		for (int s = 1; s <= n; s++) {
			if (!in[s]) {
				dfs(s);
				break;
			}
		}
		printf("%lld\n", ans);
	}
}
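
The recursive dfs above can go N = 10^5 levels deep when the tree is a chain, which may exceed the default stack size on some judges. Below is a sketch of the same algorithm (compression, a Fenwick tree over the root-to-node path, DFS) packaged as a function that uses an explicit stack instead of recursion. The function name countWeakPairs and its children-list and root parameters are assumptions for illustration; input parsing is left out.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical function form of the solution. children[u] lists the children
// of node u (nodes 1..n, index 0 unused), a[u] is the value of node u
// (a[0] unused), and root is the node with no parent.
long long countWeakPairs(const std::vector<std::vector<int>>& children,
                         const std::vector<long long>& a, long long k, int root) {
    const int n = static_cast<int>(a.size()) - 1;
    assert(static_cast<int>(children.size()) == n + 1);

    // need[v]: largest ancestor value allowed for v (10^18 if a[v] == 0).
    std::vector<long long> need(n + 1), all;
    for (int v = 1; v <= n; ++v) {
        need[v] = (a[v] == 0) ? 1000000000000000000LL : k / a[v];
        all.push_back(a[v]);
        all.push_back(need[v]);
    }
    std::sort(all.begin(), all.end());
    all.erase(std::unique(all.begin(), all.end()), all.end());
    auto rank = [&](long long x) {
        return static_cast<int>(std::lower_bound(all.begin(), all.end(), x) - all.begin()) + 1;
    };

    // Fenwick tree of counts over ranks 1..m.
    const int m = static_cast<int>(all.size());
    std::vector<int> bit(m + 1, 0);
    auto add = [&](int x, int d) {
        for (; x <= m; x += x & -x) bit[x] += d;
    };
    auto prefix = [&](int x) {
        int s = 0;
        for (; x > 0; x -= x & -x) s += bit[x];
        return s;
    };

    long long pairs = 0;
    // Each entry: (node, index of the next child to visit).
    std::vector<std::pair<int, std::size_t>> stack;
    auto enter = [&](int v) {
        pairs += prefix(rank(need[v]));  // ancestors u with a[u] <= need[v]
        add(rank(a[v]), 1);              // v is now on the path
        stack.emplace_back(v, 0);
    };

    enter(root);
    while (!stack.empty()) {
        auto& top = stack.back();
        const int u = top.first;
        if (top.second < children[u].size()) {
            const int child = children[u][top.second++];
            enter(child);  // may reallocate stack; top is not used afterwards
        } else {
            add(rank(a[u]), -1);  // u's subtree is done: remove u from the path
            stack.pop_back();
        }
    }
    return pairs;
}
```

For the sample, countWeakPairs({{}, {2}, {}}, {0, 1, 2}, 3, 1) returns 1. Because the stack remembers which child comes next, a node's value is removed from the Fenwick tree only after its whole subtree has been processed, so siblings never count each other.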