题面:
题意:
给定一棵带边权的树,问 (x−y)的路径上权值小于等于 k的边有多少条。
题解:
建立树上主席树, ans=ans(1−x)+ans(1−y)−2∗ans(1−lca(x,y))
其中 ans(1−x)表示,路径 1−−x上面权值小于等于 k的边数,这是主席树的基操。
代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#include<set>
#include<deque>
#include<map>
#include<vector>
#include<cmath>
#define ll long long
#define llu unsigned ll
using namespace std;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const int inf = 0x3f3f3f3f;
const int maxn=200100;
const int mod=1e9+7;
struct node
{
int lc,rc;
int val;
}t[maxn*22];
int root[maxn],b[maxn];
int head[maxn],ver[maxn<<1],nt[maxn<<1],edge[maxn<<1];
int f[maxn][20],d[maxn];
int xx[maxn],yy[maxn],zz[maxn];
int u[maxn],v[maxn],k[maxn];
int tot=1,cnt=0,cm,tt,ct;
void init(int n)
{
memset(head,0,sizeof(head));
tot=1,cnt=0,cm=0,tt=0;
root[0]=0;
ct=log(n)/log(2)+1;
}
void add(int x,int y,int z)
{
ver[++tot]=y,edge[tot]=z;
nt[tot]=head[x],head[x]=tot;
}
int change(int now,int pos,int l,int r)
{
int p=++cnt;
t[p]=t[now];
if(l==r)
{
t[p].val++;
return p;
}
int mid=(l+r)>>1;
if(pos<=mid) t[p].lc=change(t[now].lc,pos,l,mid);
else t[p].rc=change(t[now].rc,pos,mid+1,r);
t[p].val=t[t[p].lc].val+t[t[p].rc].val;
return p;
}
int ask(int x,int y,int fa,int nl,int nr,int l,int r)
{
if(nl<=l&&r<=nr)
{
//cout<<t[x].val<<" "<<t[y].val<<endl;
return t[x].val+t[y].val-2*t[fa].val;
}
int mid=(l+r)>>1;
int ans=0;
if(mid>=nl) ans+=ask(t[x].lc,t[y].lc,t[fa].lc,nl,nr,l,mid);
if(mid+1<=nr) ans+=ask(t[x].rc,t[y].rc,t[fa].rc,nl,nr,mid+1,r);
return ans;
}
void dfs(int x,int fa)
{
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
if(y==fa) continue;
d[y]=d[x]+1;
f[y][0]=x;
for(int j=1;j<=ct;j++)
f[y][j]=f[f[y][j-1]][j-1];
root[y]=change(root[x],edge[i],1,cm);
dfs(y,x);
}
}
int lca(int x,int y)
{
if(d[x]>d[y]) swap(x,y);
for(int i=ct;i>=0;i--)
if(d[f[y][i]]>=d[x]) y=f[y][i];
if(x==y) return x;
for(int i=ct;i>=0;i--)
if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
int main(void)
{
int n,m;
while(scanf("%d%d",&n,&m)!=EOF)
{
init(n);
for(int i=1;i<n;i++)
{
scanf("%d%d%d",&xx[i],&yy[i],&zz[i]);
b[++tt]=zz[i];
}
for(int i=1;i<=m;i++)
{
scanf("%d%d%d",&u[i],&v[i],&k[i]);
b[++tt]=k[i];
}
sort(b+1,b+tt+1);
cm=unique(b+1,b+tt+1)-(b+1);
for(int i=1;i<n;i++)
zz[i]=lower_bound(b+1,b+cm+1,zz[i])-b;
for(int i=1;i<=m;i++)
k[i]=lower_bound(b+1,b+cm+1,k[i])-b;
for(int i=1;i<n;i++)
{
add(xx[i],yy[i],zz[i]);
add(yy[i],xx[i],zz[i]);
}
dfs(1,0);
for(int i=1;i<=m;i++)
{
int fa=lca(u[i],v[i]);
printf("%d\n",ask(root[u[i]],root[v[i]],root[fa],1,k[i],1,cm));
}
}
return 0;
}