题目链接

https://www.lydsy.com/JudgeOnline/problem.php?id=3653
https://www.luogu.org/problemnew/show/P3899

思路

a、b 都是 c 的祖先,所以三个点一定都在根节点 1 到 c 的链上。
每个询问中 a 已经给定,只需按 b 与 a 的位置关系分两种情况统计。
1.b是a的祖先:b 有 \(\min(dis[a]-1,k)\) 种选法,c 可以是 a 子树内除 a 以外的任意点,共 \(siz[a]-1\) 种,答案就是 \((siz[a]-1)\cdot\min(dis[a]-1,k)\)(代码里 a 记作 u,根的深度为 1)
2.a是b的祖先,要求\(1\le dis[b]-dis[a]\le k\),即
\(1+dis[a]\le dis[b]\le k+dis[a]\),此时 c 在 b 的子树内(不含 b),每个合法的 b 贡献 \(siz[b]-1\)
第一种情况可以 O(1) 直接算出
第二种情况用线段树合并:以深度为下标,在每个点 b 的位置加上 \(siz[b]-1\),把子树的线段树合并到 a 上,再查询深度区间 \([dis[a]+1,dis[a]+k]\) 的和

代码

#include <bits/stdc++.h>
#define ll long long
#define it_ll vector<ll>::iterator
#define it_pair vector<pair<ll,ll> >::iterator
using namespace std;
const ll N=3e5+7;
// Reads the next decimal integer from stdin and returns it.
// Every non-digit character before the number is skipped; seeing a '-' among
// the skipped characters makes the result negative.
// On end-of-input the function stops instead of spinning: it returns whatever
// has been parsed so far (0 if no digit was read).
ll read() {
    ll x=0,f=1;
    // Keep the getchar() result as int so EOF stays distinguishable from data.
    int c=getchar();
    while(c!=EOF&&(c<'0'||c>'9')) {
        if(c=='-') f=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9') {
        x=x*10+(c-'0');
        c=getchar();
    }
    return x*f;
}
// n, m: the two header integers read in main; n bounds the edge loop and the
// segment-tree index range, m is the number of queries.
// dis[u]: depth of u as assigned by dfs (dfs(1,0) gives node 1 depth 1).
// siz[u]: number of nodes in the subtree of u, including u itself.
ll n,m,dis[N],siz[N];
// Q[p]: queries attached to node p, stored as (k, index of the query).
vector<pair<ll,ll> > Q[N];
// G[u]: adjacency list of the undirected tree.
vector<ll> G[N];
// Depth-first walk from u, where f is the node we arrived from.
// Sets dis[u] to one more than dis[f] and accumulates siz[u] as the number of
// nodes in u's subtree (u included).
void dfs(ll u,ll f) {
    siz[u]=1;
    dis[u]=dis[f]+1;
    const vector<ll> &adj=G[u];
    for(size_t i=0;i<adj.size();++i) {
        const ll v=adj[i];
        if(v!=f) {
            // Visit the child first so its subtree size is final before adding it.
            dfs(v,u);
            siz[u]+=siz[v];
        }
    }
}
// Dynamically allocated sum segment trees sharing one node pool.
// Node index 0 is the "null" node: it is never written, so its tot stays 0.
namespace seg {
    // ls / rs: indices of the left / right child in e (0 = absent).
    // tot: sum of all values stored in this node's interval.
    struct node {
        ll ls,rs,tot;
    }e[N*30];
    // Number of nodes handed out so far; the next new node is e[cnt+1].
    ll cnt;

    // Adds k at position id in the tree rooted at rt, which covers [l,r].
    // rt is taken by reference so that a missing node (0) is created in place.
    void insert(ll &rt,ll l,ll r,ll id,ll k) {
        if(!rt) rt=++cnt;
        e[rt].tot+=k;
        if(l==r) return;
        ll mid=(l+r)>>1;
        if(id<=mid) insert(e[rt].ls,l,mid,id,k);
        else insert(e[rt].rs,mid+1,r,id,k);
    }

    // Returns the sum over positions in [L,R] of the tree rooted at rt, which
    // covers [l,r]. A null node, an empty range (L>R) and a range that does not
    // intersect [l,r] all yield 0, so callers may pass bounds outside the domain.
    ll query(ll rt,ll l,ll r,ll L,ll R) {
        if(!rt||L>R||R<l||r<L) return 0;
        if(L<=l&&r<=R) return e[rt].tot;
        ll mid=(l+r)>>1;
        // Non-overlapping halves return 0 through the guard above.
        return query(e[rt].ls,l,mid,L,R)+query(e[rt].rs,mid+1,r,L,R);
    }

    // Merges tree y into tree x and returns the root of the result.
    // Nodes of y are reused inside the result, so y must not be used as an
    // independent tree afterwards.
    ll merge(ll x,ll y){
        // If either side is empty, the other one (possibly 0) is the answer.
        if(!x||!y) return x+y;
        e[x].tot+=e[y].tot;
        e[x].ls=merge(e[x].ls,e[y].ls);
        e[x].rs=merge(e[x].rs,e[y].rs);
        return x;
    }
}
// rt[u]: root index (into seg::e) of the segment tree built for node u in
// solve; 0 means no tree yet.
ll rt[N];
// ans[i]: answer of the i-th query, filled in solve and printed by main.
ll ans[N];
// Answers every query attached to the subtree of u (f is u's parent).
//
// For each node a depth-indexed segment tree is built that stores siz[v]-1 at
// position dis[v] for every v in the subtree. The trees of the children are
// merged into rt[u]; merging reuses the children's nodes, which is safe because
// a child's own queries are answered inside its recursive call, before the merge.
//
// A query (k, id) at u is the sum of two parts:
//   * below: sum of siz[b]-1 over subtree nodes b with depth in
//     [dis[u]+1, dis[u]+k];
//   * above: (siz[u]-1) * min(dis[u]-1, k).
//
// The return value carries no information and is always 0; results are written
// to ans[].
ll solve(ll u,ll f) {
    seg::insert(rt[u],1,n,dis[u],siz[u]-1);
    for(it_ll it=G[u].begin();it!=G[u].end();++it) {
        if(*it==f) continue;
        solve(*it,u);
        rt[u]=seg::merge(rt[u],rt[*it]);
    }
    for(it_pair it=Q[u].begin();it!=Q[u].end();++it) {
        const ll k=it->first;
        // Clamp the depth window to the tree's index range [1,n]; an empty
        // window contributes nothing and is never passed to the query.
        const ll lo=dis[u]+1;
        const ll hi=min(n,dis[u]+k);
        const ll below=(lo<=hi)?seg::query(rt[u],1,n,lo,hi):0;
        const ll above=(siz[u]-1)*min(dis[u]-1,k);
        ans[it->second]=below+above;
    }
    // The function is declared to return ll, so it must return on every path.
    return 0;
}
// Driver: reads the tree and the queries, runs the two traversals from node 1
// and prints the answers in input order, one per line.
int main() {
    n=read();
    m=read();
    // n-1 undirected edges, stored in both adjacency lists.
    ll edges_left=n-1;
    while(edges_left-->0) {
        const ll a=read();
        const ll b=read();
        G[a].push_back(b);
        G[b].push_back(a);
    }
    // Queries are bucketed by node together with their 1-based position.
    ll idx=1;
    while(idx<=m) {
        const ll node=read();
        const ll range=read();
        Q[node].push_back(make_pair(range,idx));
        ++idx;
    }
    dfs(1,0);
    solve(1,0);
    for(ll q=1;q<=m;++q) {
        printf("%lld\n",ans[q]);
    }
    return 0;
}