J Distance on the tree
zxy好弱啊,赛中分析好久也没想出来,结果题解出来发现我是个大***,这个题我在刚刚过去的ccpc省赛补过,再次出我就不会了,啊啊啊啊,我好弱啊
https://blog.csdn.net/weixin_38686780/article/details/89295381 C 题,有兴趣的可以看看
题意: 每次在树上查询 从u -v 路径上的边的权值小于k的边有多少个
分析:我们发现每次查询可以化简成,查询每个节点到根有多少个,然后求出u,v的lca, ans(u) +ans(v)-2*ans(lca(u,v)) 就是我们的答案了,所以我们的题转化成查询从节点到根有多少个边权小于k的边,这是什么? 树上主席树啊
具体做法
1 lca初始化,我是用倍增求的lca
2 主席树初始化
3 查询
#include <bits/stdc++.h>
#define lowbit(x) (x&(-x))
using namespace std;
typedef long long LL;
const int maxn = 2e5+10;
#define Pb push_back
typedef pair<int,int> P;
// 主席树
int sum[maxn*20],Left[maxn*20],Right[maxn*20],tot,root[maxn];
int Build(int l,int r){
int rt = (++tot);
if(l == r) return rt;
int m = (l+r)>>1;
Left[rt] = Build(l,m);
Right[rt] = Build(m+1,r);
return rt;
}
int Update(int pre,int l,int r,int x,int p){
int rt = ++tot;
Left[rt] = Left[pre],Right[rt] = Right[pre],sum[rt] = sum[pre]+p;
if(l < r){
int m = (l+r)>>1;
if(x <= m)
Left[rt] = Update(Left[pre],l,m,x,p);
else
Right[rt] = Update(Right[pre],m+1,r,x,p);
}
return rt;
}
int Query(int u,int l,int r,int L,int R){
if(L <= l && R >= r) return sum[u];
int m = (l+r)>>1;
LL ans = 0;
if(L <= m)
ans += Query(Left[u],l,m,L,R);
if(R > m)
ans += Query(Right[u],m+1,r,L,R);
return ans;
}
int Hash[maxn],nn;
vector<P> G[maxn];
void dfs(int node,int fa){
for(auto c: G[node]){
if(c.first == fa) continue;
int t = lower_bound(Hash+1,Hash+nn+1,c.second)-Hash;
root[c.first] = Update(root[node],1,nn,t,1);
dfs(c.first,node);
}
}
const int maxlogv = 17;
int rt;
int parent[maxlogv][maxn];
int depth[maxn];
void dfs(int v,int p,int d){
parent[0][v] = p;
depth[v] = d;
for(int i = 0;i < G[v].size(); ++i){
int u = G[v][i].first;
if(u != p){
dfs(u,v,d+1);
}
}
}
void init(int V){
dfs(rt,-1,0);
for(int k = 0;k+1 < maxlogv; ++k){
for(int v = 1; v <= V; ++v){
if(parent[k][v] < 0) parent[k+1][v] = -1;
else parent[k+1][v] = parent[k][parent[k][v]];
}
}
}
int lca(int u,int v){
if(depth[u] > depth[v]) swap(u,v);
for(int k = 0;k < maxlogv; ++k){
if(((depth[v] - depth[u]) >> k)& 1){
v = parent[k][v];
}
}
if(u == v) return u;
for(int k = maxlogv-1; k >= 0; --k){
if(parent[k][u] != parent[k][v]){
u = parent[k][u];
v = parent[k][v];
}
}
return parent[0][u];
}
int main(void){
int n,m;
cin>>n>>m;
nn = 0;
for(int i = 2;i <= n; ++i){
int u,v,w;scanf("%d%d%d",&u,&v,&w);
G[u].Pb(P(v,w));
G[v].Pb(P(u,w));
Hash[++nn] = w;
}
rt = 1;
init(n);
Hash[++nn] = 1e9+2;
sort(Hash+1,Hash+nn+1);
nn = unique(Hash+1,Hash+nn+1)-(Hash+1);
root[1] = Build(1,nn);
// root[1] = up
dfs(1,0);
// cout<<Query(1,1,nn,2,nn)<<endl;
for(int i = 1;i <= m; ++i){
int u,v,k;scanf("%d%d%d",&u,&v,&k);
k = upper_bound(Hash+1,Hash+nn+1,k)-Hash-1;
// cout<<k<<endl;
if(k == 0)
{
puts("0");
continue;
}
int lc = lca(u,v);
printf("%d\n",Query(root[u],1,nn,1,k)+Query(root[v],1,nn,1,k)-2*Query(root[lc],1,nn,1,k));
}
return 0;
}