题目链接:https://ac.nowcoder.com/acm/contest/897/C
题目大意:给你一棵树,问你距离为3的倍数的点对有多少个。(u, v)和(v, u)为同一个。

思路:

一直没有做过树上dp,因为这题当时只差一点就做出来了,所以补了一下。

当时的思路:
用dp[i][j]表示以i为根所有子树的点到i的距离%3=j的个数。
那么dp[u][(i+w)%3]=∑ dp[v][i] (v为u的子树根节点,w为边u->v的距离)。

后来发现:子树和子树也能产生点对。

那么结果就是:∑ dp[i][j]+所有节点的子树点间产生的点对。

对与每一个子节点间产生的点对,可以用二重循环暴力枚举计数。

for(int i=head[u];i!=-1;i=e[i].next)
{
    for(int j=e[i].next;j!=-1;j=e[j].next)
    {
        ans+=dp[i][0]*dp[j][0];
        ans+=dp[i][1]*dp[j][2];
        ans+=dp[i][2]*dp[j][1];
    }
}

因为1<n<=1e5,所以暴力的复杂度可以达到O(n^2),不可取。

这里有个技巧,我们先看个题:
有n个班级,每个班级有a[i]个学生,现在要推选两名升旗手,必须来自不同的班级,问有多少种不同的组合。

O(n^2)算法:
for(int i=0;i<n;i++)
{
    for(int j=i+1;j<n;j++)
    {
        ans+=a[i]*a[j];
    }
}

如果n=4
ans= a[1]*a[2] + a[1]*a[3] + a[1]*a[4]
	+a[2]*a[3] + a[2]*a[4]
	+a[3]*a[4]
如果我们化简一下:
ans= (a[1])*a[2]
    +(a[1]+a[2])*a[3]
    +(a[1]+a[2]+a[3])*a[4]
    
括号里面就是前缀和,所以得到O(n)算法
int s=0;
for(int i=0;i<n;i++)
{
    ans=s*a[i];
    s+=a[i];
}
   
这样题目就可以做了,把前面的子树合并,与新的子树计算节点的子树点间新产生的点对。

对代码比较难理解的四个部分,进行了图解:
1:这个时候dp[v]并不能代表u-v2的子树,因为w还没有加入dp[v],所以1的作用就是把w加入dp[v]使其成为子树u->v2。

现在相当于这样:

2:加上u和v2子树中的点满足条件的点对数量。
3:加上v1子树中的点和v2子树中的点满足条件的点的数量
4:把v1和v2合并成v1(类似前缀和)

再dfs(v3)…

#include<bits/stdc++.h>
#define LL long long
using namespace std;

struct node{
    int v;
    int w;
    int next;
}e[200005];

int head[100005], cut=-1;
LL dp[100005][3], ans=0;

void into()
{
    cut=-1, ans=0;
    memset(head, -1, sizeof(head));
    memset(dp, 0, sizeof(dp));
}

void add(int u,int v, int w)
{
    cut++;
    e[cut].v=v;
    e[cut].w=w;
    e[cut].next=head[u];
    head[u]=cut;
    cut++;
    e[cut].v=u;
    e[cut].w=w;
    e[cut].next=head[v];
    head[v]=cut;
}

void dfs(int w, int u, int fa)
{
    for(int i=head[u];i!=-1;i=e[i].next)
    {
        int v=e[i].v, w=e[i].w;
        if(v!=fa)
        {
            dfs(w, v, u);
            
            LL x[3]={0};
            for(int i=0;i<3;i++)
            {
                x[i]=dp[v][i];
            }
            for(int i=0;i<3;i++)//1:用v代替这条链
            {

                dp[v][(i+w)%3]=x[i];
            }
            dp[v][w]++;
            ans+=dp[v][0];//2:得到这条链到u的长度为3的倍数的个数

            ans+=dp[u][0]*dp[v][0];//3:以前的链和这个链的组合
            ans+=dp[u][1]*dp[v][2];
            ans+=dp[u][2]*dp[v][1];

            for(int i=0;i<3;i++)//4:链合并
            {
                dp[u][i]+=dp[v][i];
            }
        }
    }
}



int main()
{
    int t;
    scanf("%d",&t);
    while(t--)
    {
        into();
        int n, u, v, w;
        scanf("%d",&n);
        for(int i=0;i<n-1;i++)
        {
            scanf("%d%d%d",&u,&v,&w);
            w%=3;
            add(u, v, w);
        }
        dfs(-1, 1, -1);
        printf("%lld\n",ans);
    }

    return 0;
}