南昌邀请赛 J、 Distance on the tree

https://nanti.jisuanke.com/t/38229

树剖+主席树

先用树剖将树变成线性结构,然后用主席树维护。

1、感觉直接离散化 20w(n+q) 个数字是会内存超限的,(虽然实际上并没有),所以我们可以先离散化 10w(n)个数字 ,然后将剩下的 q 次询问的数字,upper_bound来找。

2、但是可能存在询问的数字比之前离散化的数字还要小的情况,所以我们考虑将 0 加入到离散化的数组中,这样就保证了所有数字离散化出来是非负数。

3、但是主席树维护的范围是正整数,所以我在在离散化和upperbound的时候,将所有离散化后的数字再加上1,这样就解决啦。

#include <bits/stdc++.h>
#define imid int mid=(left+right)/2;
#define lson left,mid
#define rson mid+1,right
using namespace std;
const int MAXN = 100005;
struct edge
{
	int to;
	int nex;
}e[MAXN * 2];
int head[MAXN], tot;
int son[MAXN], deep[MAXN], fa[MAXN], num[MAXN];
int top[MAXN], p[MAXN];
int pos, n, m, q;
void add(int a, int b)
{
	e[tot] = edge{ b,head[a] };
	head[a] = tot++;
}
//树剖
void init()
{
	tot = 0;
	memset(head, -1, sizeof(head));
	pos = 1;
	memset(son, -1, sizeof(son));
}
void dfs1(int u, int pre, int dep)
{
	fa[u] = pre;
	num[u] = 1;
	deep[u] = dep;
	for (int i = head[u]; i + 1; i = e[i].nex)
	{
		int v = e[i].to;
		if (v != pre)
		{
			dfs1(v, u, dep + 1);
			num[u] += num[v];
			if (son[u] == -1 || num[son[u]] < num[v])
				son[u] = v;
		}
	}
}
void dfs2(int u, int sp)
{
	top[u] = sp;
	p[u] = pos++;
	if (son[u] == -1)
		return;
	dfs2(son[u], sp);
	for (int i = head[u]; i + 1; i = e[i].nex)
	{
		int v = e[i].to;
		if (v != fa[u] && v != son[u])
			dfs2(v, v);
	}
}
//主席树
struct node
{
	int l;
	int r;
	int sum;
}tree[MAXN * 20];
int root[MAXN], cnt;
void inits()
{
	root[0] = 0;
	tree[0].l = tree[0].r = tree[0].sum = 0;
	cnt = 1;
}
void build(int num, int& rot, int left, int right)
{
	tree[cnt] = tree[rot];
	rot = cnt++;
	tree[rot].sum++;
	if (left == right)
		return;
	imid;
	if (num <= mid)
		build(num, tree[rot].l, lson);
	else
		build(num, tree[rot].r, rson);
}
int query(int pre, int nex, int num, int left, int right)
{
	int s = tree[tree[nex].l].sum - tree[tree[pre].l].sum;
	imid;
	if (num < mid)
		return query(tree[pre].l, tree[nex].l, num, lson);
	else if (num > mid)
		return s + query(tree[pre].r, tree[nex].r, num, rson);
	else
		return s;
}

int change(int u, int v, int val, int all)
{
	int ans = 0;
	int f1 = top[u], f2 = top[v];
	while (f1 != f2)
	{
		if (deep[f1] < deep[f2])
		{
			swap(f1, f2);
			swap(u, v);
		}
		ans += query(root[p[f1] - 1], root[p[u]], val, 1, all);
		u = fa[f1];
		f1 = top[u];
	}
	if (u == v)
		return ans;
	if (deep[u] > deep[v])
		swap(u, v);
	ans += query(root[p[son[u]] - 1], root[p[v]], val, 1, all);
	return ans;
}

int in[MAXN][3], t[MAXN], a[MAXN];
int main()
{
	init();
	scanf("%d%d", &n, &q);
	for (int i = 0; i < n - 1; i++)
	{
		scanf("%d%d%d", &in[i][0], &in[i][1], &in[i][2]);
		add(in[i][0], in[i][1]);
		add(in[i][1], in[i][0]);
		t[i] = in[i][2];
	}
	dfs1(1, 0, 0);
	dfs2(1, 1);
	t[n - 1] = 0;
	sort(t, t + n);
	int all = unique(t, t + n) - t;
	for (int i = 0; i < n - 1; i++)
	{
		if (deep[in[i][0]] > deep[in[i][1]])
			swap(in[i][0], in[i][1]);
		a[p[in[i][1]]] = lower_bound(t, t + all, in[i][2]) - t + 1;
	}
	inits();
	for (int i = 1; i <= n; i++)
	{
		root[i] = root[i - 1];
		build(a[i], root[i], 1, all + 10);
	}
	while (q--)
	{
		int u, v, val;
		scanf("%d%d%d", &u, &v, &val);
		val = upper_bound(t, t + all, val) - t;
		int res = change(u, v, val, all + 10);
		printf("%d\n", res);
	}
}