题目链接:https://nanti.jisuanke.com/t/42586
题目大意:
图片说明
思路:我们考虑如果是dis(x, y)==k。直接STL树上启发式合并。map<pair<int, int>, int> map[x]:x节点的子树的深度,和值的节点个数。

如果是<=k。那么查询的深度就是一个范围。我们用轻重链启发式合并。对每个权值建立线段树。下标为深度。维护节点个数和。

#pragma GCC optimize(2)
#include<bits/stdc++.h>
#define LL long long
using namespace std;

inline int read(){
   int s=0,w=1;
   char ch=getchar();
   while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
   while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
   return s*w;
}

struct setTree {
    int lc[5000050], rc[5000050], s[5000050], tot=1;

    inline void Delete(int n, vector<int> qk){
        for(int i=100005; i<=tot; i++){
            lc[i]=rc[i]=s[i]=0;
        }
        //清空用到的根
        for(auto x: qk){
            lc[x]=rc[x]=s[x]=0;
        }
        //预留n个节点,作为n个权值的根
        tot=n;
    }

    inline void up_data(int &i, int l, int r, int x, int v){
        //cout<<i<<" "<<l<<" "<<r<<" "<<x<<" "<<v<<endl;
        if(i==0){
            i=++tot;
        }
        if(l==r){
            s[i]+=v;
            return ;
        }
        int mid=(l+r)>>1;
        if(x<=mid){
            up_data(lc[i], l, mid, x, v);
        }
        else{
            up_data(rc[i], mid+1, r, x, v);
        }
        s[i]=s[lc[i]]+s[rc[i]];
    }

    inline int query(int i, int l, int r, int L, int R){
        if(i==0||R<L){
            return 0;
        }
        if(l==L&&r==R){
            return s[i];
        }
        int mid=(l+r)>>1;
        if(R<=mid) return query(lc[i], l, mid, L, R);
        else if(L>mid) return query(rc[i], mid+1, r, L, R);
        return query(lc[i], l, mid, L, mid)+query(rc[i], mid+1, r, mid+1, R);
    }
}Tree;

int v[100005], w[100005];
int n, k;
LL ans=0;
struct Treedsu{
    vector<int> G[100005], qk;
    int s[100005], dfn[100005], son[100005], deep[100005], T=0;
    inline void dfs(int u, int fa){
        s[u]=1; dfn[u]=++T, deep[T]=deep[dfn[fa]]+1; w[T]=v[u];
        for(auto x: G[u]){
            if(x!=fa){
                dfs(x, u); s[u]+=s[x];
                son[u]=s[x]>s[son[u]]?x:son[u];
            }
        }
    }
    //添加
    inline void add(int x){
        qk.push_back(w[x]);
        Tree.up_data(w[x], 1, 100005, deep[x], 1);
    }
    //计算贡献
    inline LL getans(int u, int x){
        LL ans=0;
        int s1=2*(w[u]-1)-(w[x]-1)+1;
        int r=k+2*deep[u]-deep[x];
        r=min(r, 100005);
        if(s1>=1&&s1<=100005){
            ans+=Tree.query(s1, 1, 100005, 1, r);
        }

        return ans;
    }
    inline void Dsu(int u, int fa, int ok){
        //把除了重链的所有的子树的答案计算出来
        for(auto x: G[u]){
            if(x!=fa&&x!=son[u]){
                Dsu(x, u, 0);
            }
        }

        //计算u这棵树重链的答案
        if(son[u]) Dsu(son[u], u, 1);
        for(auto x: G[u]){
            if(x!=fa&&x!=son[u]){
                //计算轻链的答案
                for(int i=0; i<s[x]; i++){
                    ans+=getans(dfn[u], dfn[x]+i);
                }

                //合并轻链
                for(int i=0; i<s[x]; i++){
                    add(dfn[x]+i);
                }
            }
        }

        add(dfn[u]);//把u自己加进去

        //是否保留线段树
        if(!ok){
            Tree.Delete(100005, qk);
            qk.clear();
        }
    }

}dsu;

int main() {

    int x;
    scanf("%d%d", &n, &k);
    for(int i=1; i<=n; i++){
        v[i]=read();
        //权值可以为0,线段树节点>0
        v[i]++;
    }
    for(int i=2; i<=n; i++){
        x=read();
        dsu.G[i].push_back(x);
        dsu.G[x].push_back(i);
    }
    Tree.Delete(100005, dsu.qk);
    dsu.dfs(1, 0);
    dsu.Dsu(1, 0, 1);//多样例不保留
    printf("%lld\n", ans*2);

    return 0;
}