Tree

前言

树上启发式合并好啊,可惜还不太会啊图片说明

分析

1.凡事始于朴素:根据题目,我得在每一个子树内部去寻找答案。
图片说明
假设是这样的一棵树。首先在3节点的子树中去找,首先遍历4节点,
图片说明
统计一下,这时的lca是3,记录一个k并求出图片说明 表示另一个数的出现次数,同时图片说明 。然后去到5节点,
图片说明
同样的,首先找到满足图片说明 的数的个数,同时将图片说明
然后进入10号节点....
图片说明
这个时候我们发现已经统计完了3节点的子树,向上走一步同时vis数组清零,到达2号,根据统计3节点子树的方法一样,先把3节点所有子树加入贡献(注意,是把vis清零后再重新跑一遍3节点的子树)
图片说明
然后再进入6节点,统计答案....

2.会发现,有许多不必要的重复操作,我为什么要把3节点子树的vis清零呢?我一定要把所有的都清零吗?于是树上启发式合并就来了。名字听着挺nb的,其实就是最大减少重复操作。就比如,在这棵树中,我如果要减少重复操作,明显就得最小化进入节点较多的子树的次数。也就是说,尽可能多的保留重儿子的信息(不清零),在统计时,只需要搜索轻儿子。
先给出代码

inline void dsu(int u,int v,bool w)
{
    for (int i=h[u];~i;i=nex[i])
    {
        int j=ver[i];
        if(j==v||j==son[u]) continue;
        dsu(j,u,0);
    }
    if(son[u]) dsu(son[u],u,1);

    for (int i=h[u];~i;i=nex[i])
    {
        int j=ver[i];
        if(j==v||j==son[u]) continue;

        cal(j,u,u),upd(j,u,1);
    }

    k[val[u]]++;
    if(!w) upd(u,v,-1);
}

然后看图模拟一遍(以1节点为lca统计答案):首先,重儿子(节点数最多的那个)为2,先跑入7,9节点统计完这两个子树对答案产生的贡献之后,再加入重儿子
图片说明
会发现,此时2节点的子树内部信息不会被清零,
图片说明
然后开始与轻儿子统计对答案的贡献,避免了重新进入2号节点统计每一个值的出现个数的问题。

代码

/*树上启发式合并*/
#include<bits/stdc++.h>

#define R register
#define ll long long
#define inf INT_MAX

using namespace std;

const int N=1e5+10;

int n,tot;ll ans;
int h[N],nex[N<<1],ver[N<<1];
int son[N],val[N],siz[N];

map<int,int>k;

inline void add(int x,int y)
{
    nex[tot]=h[x];
    ver[tot]=y;
    h[x]=tot++;
}

inline void dfs(int u,int v)
{
    siz[u]=1;
    for (int i=h[u];~i;i=nex[i])
    {
        int j=ver[i];
        if(j==v) continue;

        dfs(j,u);

        siz[u]+=siz[j];
        if(siz[son[u]]<siz[j]) son[u]=j;
    }
}

inline void cal(int u,int v,int lca)
{
    ans+=(ll)k[2*val[lca]-val[u]];
    for (int i=h[u];~i;i=nex[i])
        if(ver[i]!=v) cal(ver[i],u,lca);
}

inline void upd(int u,int v,int va)
{
    k[val[u]]+=va;
    for (int i=h[u];~i;i=nex[i])
        if(ver[i]!=v)
            upd(ver[i],u,va);
}

inline void dsu(int u,int v,bool w)
{
    for (int i=h[u];~i;i=nex[i])
    {
        int j=ver[i];
        if(j==v||j==son[u]) continue;
        dsu(j,u,0);
    }
    if(son[u]) dsu(son[u],u,1);

    for (int i=h[u];~i;i=nex[i])
    {
        int j=ver[i];
        if(j==v||j==son[u]) continue;

        cal(j,u,u),upd(j,u,1);
    }

    k[val[u]]++;
    if(!w) upd(u,v,-1);
}

int main()
{
    memset(h,-1,sizeof(h));
    scanf("%d",&n);
    for (int i=1;i<=n;i++) scanf("%d",&val[i]);
    for (int i=1;i<n;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y),add(y,x);
    }

    dfs(1,0);
    dsu(1,1,0);

    printf("%lld\n",ans*2ll);

    return 0;
}

后话

篇幅极短,且用词不规范,不知是否会误导其他人(应该也没多少人看)