牛客2927E - 树
题意
- 有一棵 n 个节点的树,每条边长度为 1,设 dis(u,v) 为 u 到 v 的距离。
- 求 ∑i=1n∑j=1ndis2(i,j)
- 答案对 998244353 取模。
思路
如果将此题改为:求 ∑i=1n∑j=1ndis(i,j)
- 从边下手。对于某一条边,它对答案的贡献为:边权×该边上方点的数量×该边下方点的数量
- 之所以能这么做,是因为 dis(u,v) 的计算是线性的。
- 详细来说,就是能写成 dis(u,v)=dis(u,x)+dis(x,v) 的形式
现在,求 ∑i=1n∑j=1ndis2(i,j)
- 显然,dis2(u,v)=[dis(u,x)+dis(x,v)]2,不能直接相加了。
- 不能直接相加,我们可以通过对等式进行变换,使其变成线性的
- 实际上,dis2(u,v)=dis2(u,x)+dis2(x,v)+2dis(u,x)dis(x,v),这个等式是线性的。
- 这意味着,对于上式中提到的中间节点 x,如果它的一边连着节点 u,另一边连着节点 v,
- 那么,上面式子的那三项直接相加,就是 dis2(u,v)。
- 推广一下,如果中间LCA节点 x 的两边有多个节点,需要计算两两之间的 dis2,那么维护每个节点到所有下方节点的 dis2总和 和 2×disx×disy总和就可以了。
- 我们从每个节点下手,进行树形DP。
- 这样的话,我们需要使用两个DP数组dp1和dp2,分别维护 当前LCA节点到所有下方节点的距离总和∑dis 和 当前LCA节点到所有下方节点的距离平方的总和∑dis2
- 方法:
- dp1[cur]=∑(dp1[nxt]+sz[nxt])
- 对于当前节点 cur 的某个下方节点 u,设cur的父节点为fa,那么dis(fa,u)=dis(cur,u)+1,
- 注意观察 (x+1)2−x2=2x+1
- 那么dis2(fa,u)=dis2(cur,u)+2×dis(cur,u)+1
- 推广到这里就是:dp2[cur]=∑nxt(dp2[nxt]+2×dp1[nxt]+sz[nxt])
如何统计答案呢?
- 对于一个节点 cur,先计算 ∑dis2(cur,cur子树所有节点)
- 再计算 cur 作为LCA时的情况。
代码
#include <cstdio>
#include <iostream>
#include <vector>
#define int long long
const int N = 1e6+10;
const int MOD = 998244353;
using namespace std;
int dp1[N],dp2[N];
int sz[N];
vector<int> G[N],W[N];
int n, ans=0;
void DFS1(int cur,int pre)
{
sz[cur] = 1;
for(auto nxt:G[cur])
{
if(nxt == pre)continue;
DFS1(nxt, cur);
sz[cur] += sz[nxt];
dp1[cur] += (dp1[nxt]+sz[nxt])%MOD, dp1[cur]%=MOD;
dp2[cur] += (dp2[nxt] + ((dp1[nxt]+sz[nxt])%MOD*2%MOD-sz[nxt]+MOD)%MOD)%MOD, dp2[cur]%=MOD;
}
}
void DFS2(int cur,int pre)
{
for(auto nxt:G[cur])
{
if(nxt == pre)continue;
DFS2(nxt, cur);
}
ans += dp2[cur]*2%MOD, ans%=MOD;
for(auto nxt:G[cur])
{
if(nxt == pre)continue;
int other1 = (dp1[cur] - (dp1[nxt]+sz[nxt])%MOD + MOD)%MOD;
int tmp1 = 2 * (dp1[nxt]+sz[nxt]) % MOD * other1 % MOD;
int other2 = (dp2[cur] - dp2[nxt] - ((dp1[nxt]+sz[nxt])%MOD*2%MOD+MOD-sz[nxt]+MOD)%MOD + MOD)%MOD;
int tmp2 = (dp2[nxt]+(dp1[nxt]+sz[nxt])%MOD*2%MOD-sz[nxt]+MOD)%MOD;
ans += (sz[cur]-1-sz[nxt]) * tmp2 % MOD + sz[nxt]*other2 % MOD + tmp1, ans%=MOD;
}
}
void Solve()
{
DFS1(1, 0);
DFS2(1, 0);
printf("%lld\n",ans);
}
signed main()
{
scanf("%lld",&n);
for (int i=1; i<=n-1; i++)
{
int u,v;
scanf("%lld %lld",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
Solve();
return 0;
}