题目链接

题意

一个有边权的树,提问两点间路径上的边权小于等于某个数的个数

题解

树上主席树模板题,每个节点从他的父亲继承即可

代码

#include<bits/stdc++.h>
#define N 100010
#define INF 0x3f3f3f3f
#define eps 1e-7
#define pi 3.141592653589793
#define mod 998244353
// #define P 1000000007
#define LL long long
#define pb push_back
#define fi first
#define se second
#define cl clear
#define si size
#define lb lower_bound
#define ub upper_bound
#define bug(x) cerr<<#x<<" : "<<x<<endl
#define mem(x,y) memset(x,0,sizeof(int)*(y+3))
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// uniform_real_distribution<double> dis(-1.0,1.0);

typedef pair<int,int> pp;
vector<pp> G[N],a[N];
int n,m,rt[N],u[N<<1],v[N<<1],f[N<<1],z[N<<1],t[N<<1],vis[N],tot,fa[N],rs[N<<5],ls[N<<5],sum[N<<5],cnt;

int getfa(int x){
    return fa[x]==x?x:fa[x]=getfa(fa[x]);
}

void ins(int &i,int old,int x,int l=1,int r=tot){
    i=++cnt; ls[i]=ls[old],rs[i]=rs[old],sum[i]=sum[old]+1;
    if (l==r) return;
    int mid=l+r>>1;
    if (x<=mid) ins(ls[i],ls[old],x,l,mid);else
    ins(rs[i],rs[old],x,mid+1,r);
}

int query(int j,int i,int x,int l=1,int r=tot){
    int mid,t=0;
    while(l<r){
        mid=l+r>>1;
        if (mid>=x) i=ls[i],j=ls[j],r=mid;
        else t+=sum[ls[j]]-sum[ls[i]],i=rs[i],j=rs[j],l=mid+1;
    }
    return t+sum[j]-sum[i];
}

void dfs(int x,int ffa){
    for(auto i:G[x]) if (i.fi!=ffa){
        ins(rt[i.fi],rt[x],i.se);
        dfs(i.fi,x);
        fa[i.fi]=x;
    }
    for(auto i:a[x]) if (vis[i.fi])
        t[i.se]=getfa(i.fi);
    vis[x]=1;
}

int main(int argc, char const *argv[])
{
    scc(n,m);
    for(int i=1;i<=n;i++) fa[i]=i;
    for(int i=1;i<n+m;i++){
        sccc(u[i],v[i],z[i]);
        f[++tot]=z[i];
    }
    sort(f+1,f+tot+1);
    tot=unique(f+1,f+tot+1)-f-1;
    for(int i=1;i<n+m;i++){
        z[i]=lb(f+1,f+tot+1,z[i])-f;
        if (i<n)
            G[u[i]].pb(pp(v[i],z[i])),
            G[v[i]].pb(pp(u[i],z[i]));
        else
            a[u[i]].pb(pp(v[i],i)),
            a[v[i]].pb(pp(u[i],i));
    }
    dfs(1,0);
    for(int i=n;i<n+m;i++){
        if (u[i]==v[i]) puts("0");else
            printf("%d\n",query(rt[u[i]],rt[t[i]],z[i])+query(rt[v[i]],rt[t[i]],z[i]));
    }
    return 0;
}