题目链接:https://ac.nowcoder.com/acm/contest/897/C
题目大意:给你一棵树,问你距离为3的倍数的点对有多少个。(u, v)和(v, u)为同一个。
思路:
一直没有做过树上dp,因为这题当时只差一点就做出来了,所以补了一下。
当时的思路:
用dp[i][j]表示以i为根所有子树的点到i的距离%3=j的个数。
那么dp[u][(i+w)%3]=∑ dp[v][i] (v为u的子树根节点,w为边u->v的距离)。
后来发现:子树和子树也能产生点对。
那么结果就是:∑ dp[i][j]+所有节点的子树点间产生的点对。
对与每一个子节点间产生的点对,可以用二重循环暴力枚举计数。
for(int i=head[u];i!=-1;i=e[i].next)
{
for(int j=e[i].next;j!=-1;j=e[j].next)
{
ans+=dp[i][0]*dp[j][0];
ans+=dp[i][1]*dp[j][2];
ans+=dp[i][2]*dp[j][1];
}
}
因为1<n<=1e5,所以暴力的复杂度可以达到O(n^2),不可取。
这里有个技巧,我们先看个题:
有n个班级,每个班级有a[i]个学生,现在要推选两名升旗手,必须来自不同的班级,问有多少种不同的组合。
O(n^2)算法:
for(int i=0;i<n;i++)
{
for(int j=i+1;j<n;j++)
{
ans+=a[i]*a[j];
}
}
如果n=4
ans= a[1]*a[2] + a[1]*a[3] + a[1]*a[4]
+a[2]*a[3] + a[2]*a[4]
+a[3]*a[4]
如果我们化简一下:
ans= (a[1])*a[2]
+(a[1]+a[2])*a[3]
+(a[1]+a[2]+a[3])*a[4]
括号里面就是前缀和,所以得到O(n)算法
int s=0;
for(int i=0;i<n;i++)
{
ans=s*a[i];
s+=a[i];
}
这样题目就可以做了,把前面的子树合并,与新的子树计算节点的子树点间新产生的点对。
对代码比较难理解的四个部分,进行了图解:
1:这个时候dp[v]并不能代表u-v2的子树,因为w还没有加入dp[v],所以1的作用就是把w加入dp[v]使其成为子树u->v2。
现在相当于这样:
2:加上u和v2子树中的点满足条件的点对数量。
3:加上v1子树中的点和v2子树中的点满足条件的点的数量
4:把v1和v2合并成v1(类似前缀和)
再dfs(v3)…
#include<bits/stdc++.h>
#define LL long long
using namespace std;
struct node{
int v;
int w;
int next;
}e[200005];
int head[100005], cut=-1;
LL dp[100005][3], ans=0;
void into()
{
cut=-1, ans=0;
memset(head, -1, sizeof(head));
memset(dp, 0, sizeof(dp));
}
void add(int u,int v, int w)
{
cut++;
e[cut].v=v;
e[cut].w=w;
e[cut].next=head[u];
head[u]=cut;
cut++;
e[cut].v=u;
e[cut].w=w;
e[cut].next=head[v];
head[v]=cut;
}
void dfs(int w, int u, int fa)
{
for(int i=head[u];i!=-1;i=e[i].next)
{
int v=e[i].v, w=e[i].w;
if(v!=fa)
{
dfs(w, v, u);
LL x[3]={0};
for(int i=0;i<3;i++)
{
x[i]=dp[v][i];
}
for(int i=0;i<3;i++)//1:用v代替这条链
{
dp[v][(i+w)%3]=x[i];
}
dp[v][w]++;
ans+=dp[v][0];//2:得到这条链到u的长度为3的倍数的个数
ans+=dp[u][0]*dp[v][0];//3:以前的链和这个链的组合
ans+=dp[u][1]*dp[v][2];
ans+=dp[u][2]*dp[v][1];
for(int i=0;i<3;i++)//4:链合并
{
dp[u][i]+=dp[v][i];
}
}
}
}
int main()
{
int t;
scanf("%d",&t);
while(t--)
{
into();
int n, u, v, w;
scanf("%d",&n);
for(int i=0;i<n-1;i++)
{
scanf("%d%d%d",&u,&v,&w);
w%=3;
add(u, v, w);
}
dfs(-1, 1, -1);
printf("%lld\n",ans);
}
return 0;
}