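// Got WA at first; rewrote this solution from scratch and it passed. My first attempt was garbage.
// Problem: https://nanti.jisuanke.com/t/38229
// As shown in the figure below.
// For every tree node, build a persistent segment tree (chairman tree) along each chain leading onward from it.
// That way, running an LCA query is enough to answer the range "k-th largest" query quickly.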
#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;
// Capacity of every per-vertex array: 100000 entries plus 10 spare slots.
const int maxn = 100000 + 10;

// Adjacency lists are kept as linked edge records. Record ids start at 1
// (ade() pre-increments cnt), so id 0 terminates a list.
int head[maxn];       // head[v]: id of the record most recently added for v by ade()
int depth[maxn];      // depth[v]: assigned in dfs() as depth[parent] + 1
int cnt;              // number of edge records created so far by ade()

// Right end of the index range handed to build()/ins()/ask(); set by lsh().
int shu;

int dis[maxn << 1];   // dis[e]: value c passed to ade() for record e
int nxt[maxn << 1];   // nxt[e]: record that was the list head before e was added
int to[maxn << 1];    // to[e]: endpoint b of record e

int fa[maxn][30];     // fa[v][0] is the dfs parent; fa[v][i] = fa[fa[v][i-1]][i-1]
int lg[maxn];         // filled in main(); lca() uses lg[d] - 1 as a jump exponent

int n;                // read in main(): vertex count (n - 1 edges follow)
int m;                // read in main(): number of queries
int k;                // not referenced at global scope in this file (locals shadow it)
int q;                // not referenced at global scope in this file (ask() shadows it)
// Adds one directed edge record a -> b carrying value c.
// The new record is pushed onto the front of a's adjacency list.
void ade(int a, int b, int c) {
    const int id = ++cnt;   // ids start at 1; 0 is the list terminator
    to[id] = b;
    dis[id] = c;
    nxt[id] = head[a];      // chain to the previous front of the list
    head[a] = id;
}
// One segment-tree node. Children are referenced by index into tree[].
struct node {
    int lc;    // index of the left child
    int rc;    // index of the right child
    int val;   // sum stored for this node's range
};

// Node pool shared by build() and ins(); slots are handed out via ++tot.
node tree[maxn * 20];

// Index of the last pool slot handed out.
int tot;

// root[v]: pool index of the tree version associated with vertex v (set in main()/dfs()).
int root[maxn];

// b[1..n-1]: edge values collected in main(), then sorted and de-duplicated by lsh().
int b[maxn];
// Allocates a segment tree covering [l, r] in which every value is 0.
// Returns the pool index of the subtree's root.
int build(int l, int r) {
    const int id = ++tot;
    if (l != r) {
        const int mid = (l + r) >> 1;
        tree[id].lc = build(l, mid);
        tree[id].rc = build(mid + 1, r);
        tree[id].val = tree[tree[id].lc].val + tree[tree[id].rc].val;
    } else {
        // Leaf: only the value is written; child links are left untouched.
        tree[id].val = 0;
    }
    return id;
}
int ins(int now, int l, int r, int pos, int val) {
int p = ++ tot;
tree[p] = tree[now];
if(l == r) {
tree[p].val += val ;
return p;
}
int mid = l + r >> 1;
if(pos <= mid)
tree[p].lc = ins(tree[now].lc, l, mid, pos, val);
else
tree[p].rc = ins(tree[now].rc, mid + 1, r, pos, val);
tree[p].val = tree[tree[p].lc].val + tree[tree[p].rc].val;
return p;
}
// Returns, over positions in [L, R], the sum stored in version p minus the sum
// stored in version q. p and q are the nodes of the two versions that cover
// the same range [l, r]. An empty query range (R < L) yields 0.
int ask(int p, int q, int l, int r, int L, int R) {
    if (L > R) {
        return 0;
    }
    if (l >= L && r <= R) {
        // Node range is fully inside the query range.
        return tree[p].val - tree[q].val;
    }

    const int mid = (l + r) >> 1;
    const int leftPart = (L <= mid) ? ask(tree[p].lc, tree[q].lc, l, mid, L, R) : 0;
    const int rightPart = (R > mid) ? ask(tree[p].rc, tree[q].rc, mid + 1, r, L, R) : 0;
    return leftPart + rightPart;
}
// Maps a value to its rank among b[1..shu-1]: the 1-based index of the last
// element that is not greater than x, or 0 when every element is greater.
// Relies on b[1..shu-1] being sorted, which lsh() establishes.
int getid(int x) {
    int *const first = b + 1;
    int *const last = b + shu;
    return static_cast<int>(std::upper_bound(first, last, x) - first);
}
// Depth-first walk from x, whose parent is pre. For each visited vertex it
// records the depth and the binary-lifting ancestor table, and gives every
// child a tree version derived from x's version with the connecting edge's
// rank incremented by 1.
void dfs(int x, int pre) {
    depth[x] = depth[pre] + 1;
    fa[x][0] = pre;
    for (int j = 1; (1 << j) <= depth[x]; ++j) {
        const int halfway = fa[x][j - 1];
        fa[x][j] = fa[halfway][j - 1];
    }

    for (int e = head[x]; e != 0; e = nxt[e]) {
        const int child = to[e];
        if (child == pre) {
            continue;   // do not walk back up the edge we arrived by
        }
        root[child] = ins(root[x], 1, shu, getid(dis[e]), 1);
        dfs(child, x);
    }
}
// Returns the lowest common ancestor of x and y, using the depth[] and fa[][]
// tables filled by dfs() and the lg[] table filled in main().
int lca(int x, int y) {
    // Make x the deeper (or equally deep) vertex.
    if (depth[x] < depth[y]) {
        std::swap(x, y);
    }

    // Lift x until both vertices are at the same depth.
    while (depth[x] > depth[y]) {
        const int gap = depth[x] - depth[y];
        x = fa[x][lg[gap] - 1];
    }
    if (x == y) {
        return x;
    }

    // Lift both vertices together while their ancestors still differ.
    for (int j = lg[depth[x]] - 1; j >= 0; --j) {
        if (fa[x][j] == fa[y][j]) {
            continue;
        }
        x = fa[x][j];
        y = fa[y][j];
    }
    return fa[x][0];
}
// Discretisation step: sorts b[1..n-1], removes duplicates, and stores in shu
// the index one past the last distinct value (number of distinct values + 1).
void lsh() {
    int *const first = b + 1;
    int *const last = b + n;
    std::sort(first, last);
    shu = static_cast<int>(std::unique(first, last) - b);
}
// Program entry point.
//
// Input (stdin): "n m", then n - 1 lines "u v w" describing weighted tree
// edges, then m lines "x y k". For each query it prints, on its own line, the
// number of edges on the path x..y whose weight is not greater than k.
//
// Vertex ids are used directly as array indices and are assumed to lie in
// [1, n]; this is not validated here.
//
// Returns 1 if the header or an edge line cannot be read or n does not fit
// the fixed-size arrays; otherwise 0. A short query list ends the loop early.
int main() {
    // lg[i] ends up as floor(log2(i)) + 1; lca() subtracts 1 to get a jump exponent.
    for (int i = 1; i < maxn; i++)
        lg[i] = lg[i - 1] + ((1 << lg[i - 1]) == i);

    if (scanf("%d %d", &n, &m) != 2)
        return 1;
    // Per-vertex arrays have maxn slots and are indexed by vertex id up to n.
    if (n < 1 || n >= maxn)
        return 1;

    for (int i = 1; i < n; i++) {
        int st, ed, w;
        if (scanf("%d %d %d", &st, &ed, &w) != 3)
            return 1;
        // The tree is undirected: store the edge in both directions.
        ade(st, ed, w);
        ade(ed, st, w);
        b[i] = w;
    }

    lsh();
    root[1] = build(1, shu);   // vertex 1 is the root and starts from the all-zero tree
    dfs(1, 0);

    for (int i = 1; i <= m; i++) {
        int x, y, limit;
        if (scanf("%d %d %d", &x, &y, &limit) != 3)
            break;             // fewer queries than announced: stop instead of using garbage
        const int hi = getid(limit);   // rank of the largest stored weight <= limit (0 if none)
        const int anc = lca(x, y);
        // Each root-to-vertex version minus the LCA's version leaves exactly
        // the edges on that half of the path.
        printf("%d\n",
               ask(root[y], root[anc], 1, shu, 1, hi)
               + ask(root[x], root[anc], 1, shu, 1, hi));
    }
    return 0;
}