牛客2927E - 树

题意

  • 有一棵 nn 个节点的树,每条边长度为 11,设 dis(u,v)dis(u,v)uuvv 的距离。
  • i=1nj=1ndis2(i,j)\sum_{i=1}^{n}\sum_{j=1}^{n}dis^2(i,j)
  • 答案对 998244353998244353 取模。

思路

如果将此题改为:求 i=1nj=1ndis(i,j)\sum_{i=1}^{n}\sum_{j=1}^{n}dis(i,j)

  • 从边下手。对于某一条边,它对答案的贡献为:××边权\times该边上方点的数量\times该边下方点的数量
  • 之所以能这么做,是因为 dis(u,v)dis(u,v) 的计算是线性的。
  • 详细来说,就是能写成 dis(u,v)=dis(u,x)+dis(x,v)dis(u,v)=dis(u,x)+dis(x,v) 的形式

现在,求 i=1nj=1ndis2(i,j)\sum_{i=1}^{n}\sum_{j=1}^{n}dis^2(i,j)

  • 显然,dis2(u,v)=[dis(u,x)+dis(x,v)]2dis^2(u,v)=[dis(u,x)+dis(x,v)]^2,不能直接相加了。
  • 不能直接相加,我们可以通过对等式进行变换,使其变成线性的
  • 实际上,dis2(u,v)=dis2(u,x)+dis2(x,v)+2dis(u,x)dis(x,v)dis^2(u,v)=dis^2(u,x)+dis^2(x,v)+2dis(u,x)dis(x,v),这个等式是线性的。
  • 这意味着,对于上式中提到的中间节点 xx,如果它的一边连着节点 uu,另一边连着节点 vv
  • 那么,上面式子的那三项直接相加,就是 dis2(u,v)dis^2(u,v)
  • 推广一下,如果中间LCA节点 xx 的两边有多个节点,需要计算两两之间的 dis2dis^2,那么维护每个节点到所有下方节点的 dis2dis^2总和 和 2×disx×disy2\times disx\times disy总和就可以了。
  • 我们从每个节点下手,进行树形DP。
  • 这样的话,我们需要使用两个DP数组dp1和dp2,分别维护 当前LCA节点到所有下方节点的距离总和dis\sum dis 和 当前LCA节点到所有下方节点的距离平方的总和dis2\sum dis^2
  • 方法:
  • dp1[cur]=(dp1[nxt]+sz[nxt])dp1[cur]=\sum (dp1[nxt]+sz[nxt])
  • 对于当前节点 cur 的某个下方节点 u,设cur的父节点为fa,那么dis(fa,u)=dis(cur,u)+1dis(fa,u)=dis(cur,u)+1
  • 注意观察 (x+1)2x2=2x+1(x+1)^2-x^2=2x+1
  • 那么dis2(fa,u)=dis2(cur,u)+2×dis(cur,u)+1dis^2(fa,u)=dis^2(cur,u)+2\times dis(cur,u)+1
  • 推广到这里就是:dp2[cur]=nxt(dp2[nxt]+2×dp1[nxt]+sz[nxt])dp2[cur]=\sum_{nxt}(dp2[nxt]+2\times dp1[nxt]+sz[nxt])

如何统计答案呢?

  • 对于一个节点 cur,先计算 dis2(cur,cur)\sum dis^2(cur,cur子树所有节点)
  • 再计算 cur 作为LCA时的情况。

代码

#include <cstdio>
#include <iostream>
#include <vector>
#define int long long
const int N		= 1e6+10;
const int MOD	= 998244353;
using namespace std;

int dp1[N],dp2[N];
int sz[N];
vector<int> G[N],W[N];
int n, ans=0;

void DFS1(int cur,int pre)
{
	sz[cur] = 1;
	
	for(auto nxt:G[cur])
	{
		if(nxt == pre)continue;
		
		DFS1(nxt, cur);
		sz[cur] += sz[nxt];
		
		dp1[cur] += (dp1[nxt]+sz[nxt])%MOD, dp1[cur]%=MOD;
		
		dp2[cur] += (dp2[nxt] + ((dp1[nxt]+sz[nxt])%MOD*2%MOD-sz[nxt]+MOD)%MOD)%MOD, dp2[cur]%=MOD;
		
	}
}

void DFS2(int cur,int pre)
{
	for(auto nxt:G[cur])
	{
		if(nxt == pre)continue;
		
		DFS2(nxt, cur);
	}
	
	ans += dp2[cur]*2%MOD, ans%=MOD;
	for(auto nxt:G[cur])
	{
		if(nxt == pre)continue;
		
		int other1 = (dp1[cur] - (dp1[nxt]+sz[nxt])%MOD + MOD)%MOD;
		int tmp1 = 2 * (dp1[nxt]+sz[nxt]) % MOD * other1 % MOD;
		
		int other2 = (dp2[cur] - dp2[nxt] - ((dp1[nxt]+sz[nxt])%MOD*2%MOD+MOD-sz[nxt]+MOD)%MOD + MOD)%MOD;
		int tmp2 = (dp2[nxt]+(dp1[nxt]+sz[nxt])%MOD*2%MOD-sz[nxt]+MOD)%MOD;
		
		ans += (sz[cur]-1-sz[nxt]) * tmp2 % MOD + sz[nxt]*other2 % MOD + tmp1, ans%=MOD;
	}
}

void Solve()
{
	DFS1(1, 0);
	DFS2(1, 0);
	
	printf("%lld\n",ans);
}

signed main()
{
	scanf("%lld",&n);
	for (int i=1; i<=n-1; i++)
	{
		int u,v;
		scanf("%lld %lld",&u,&v);
		G[u].push_back(v);
		G[v].push_back(u);
	}
	
	Solve();
	
	return 0;
}