一、题目概述
给出一张图,求出 $ans=\sum_{i=1}^n\sum_{j=1}^n[i\ne j]\texttt d^k(i,j)$ 的值,其中 $\texttt d(x,y)$ 表示从 $x$ 到 $y$ 的最短路。对于所有数据满足 $1\le n\le 10^5,1\le k\le 10^9$,保证给定的图 $G$ 满足题中要求,且不存在重边。
$\texttt{subtask 1:}~ 5\%$,满足 $n\le 1000$ 。
$\texttt{subtask 2:}~10\%$,满足 $k=1$ 。
$\texttt{subtask 3:}~15\%$,满足 $k=2$ 。
$\texttt{subtask 4:}~30\%$,满足 $G$ 中存在一条边 $(u,u)$ 。
$\texttt{subtask 5:}~40\%$,无额外限制。
二、解题思路
算法0
首先肯定可以 $\mathcal O(n^3)$ 跑 $\texttt{Floyd}$ ,但是这个是集训队互测的题,可能是由于集训队大佬们都不屑于写,于是就没有这一档部分分。算法1
数据满足 $n\le 1000$ 。考虑分类讨论:
如果这是一棵树的话,那么问题就是求 $ans=\sum_{i=1}^n\sum_{j=1}^n[i\ne j](\texttt{dep}_i+\texttt{dep}_j-2\texttt{dep}_{\texttt{LCA}(i,j)})^k$ ,直接枚举即可。
如果原图是一棵基环树的话,使用基环树的常规套路,先找出环,并且删去环上的任意一条边,剩下的边按照树的方案做一次。做完之后,我们单独考虑这条边 $(u,v)$ 对最短路的贡献,即有哪些路径可以使用 $s\longrightarrow u\rightarrow v\longrightarrow t$ 作为最短路。于是我们先从 $u$ 和 $v$ 向其它点跑一边最短路,然后 $\mathcal O(n^2)$ 枚举 $s$ 和 $t$ ,然后暴力更新即可。
算法2
数据满足 $k=1$ 。考虑分类讨论:
如果原图是树,那么问题就是求 $ans=\sum_{i=1}^n\sum_{j=1}^n [i\ne j]\texttt{dep}_i+\texttt{dep}_j-2\texttt{dep}_{\texttt{LCA}(i,j)}$ ,直接枚举每个点作为路径的端点,和作为 $\texttt{LCA}$ 对答案的贡献即可。
如果原图是基环树,那么我们可以分析一下这张图的 `dfs` 树会长成什么样:
这就意味着,新加入的边如果对某一条路径有影响,那么这条路径的一定是一端在 $\texttt{v1}$ 的子树之外,另一端在 $\texttt{vn}$ 的子树之内。通过这个性质,我们就可以将基环树上的问题转化为树上的问题求解了。
这样的话,我们还可以顺便解决另外一个部分分:
数据满足 $k=2$ 。答案变为 $ans=\sum_{i=1}^n\sum_{j=1}^n [i\ne j](\texttt{dep}_i+\texttt{dep}_j-2\texttt{dep}_{\texttt{LCA}(i,j)})^2$ ,将完全平方展开后利用上面的方法维护即可。
算法3
满足 $G$ 中存在一条边 $(u,u)$ ,即原图为树。考虑到答案中存在乘方操作,不容易计算,于是令 $cnt_i$ 表示树上长度为 $i$ 的路径的个数,那么答案就是 $ans=\sum_{i=1}^ncnt_i\times i^k$。不难发现,$cnt$ 可以使用点分治维护,然后用多项式算法优化复杂度,于是这个部分分的问题就解决了。
算法4
数据满足 $1\le n\le 10^5,1\le k\le 10^9$,保证给定的图 $G$ 满足题中要求,且不存在重边。如果原图是树的话,直接用上面的算法3就好了。
如果是基环树的话,考虑基环DP,先做环以外的点的树形DP,然后在换上合并即可。
#include <iostream> #include <cstdio> #include <queue> #include <vector> using namespace std; const int mod = 998244353, inv2 = 499122177, inv3 = 332748118; vector<int> ve[500005], cnt[500005]; int n, rev[500005], cy[500005], dep[500005], fa[500005], sz[500005], vis[500005]; long long f[500005], ans[500005], a[500005], b[500005]; long long qpow(long long a, int b = mod - 2) { long long rtv = 1; for (a %= mod; b; b >>= 1, a = a * a % mod) if (b & 1) rtv = rtv * a % mod; return rtv; } inline void ntt1(long long* f, int n) { for (register int i=1;i<n;i+=2) rev[i-1]=rev[i]=rev[i>>1]>>1,rev[i]|=n>>1; for (register int i = 0; i < n; ++i) if (i < rev[i]) swap(f[i], f[rev[i]]); for (register int i = 1; i < n; i <<= 1) { long long w = qpow(3, mod / (i << 1)); for (register int j = 0; j < n; j += i << 1) { long long o = 1; for (register int k = 0; k < i; ++k, o = o * w % mod) { long long tmp1 = f[j + k], tmp2 = f[i + j + k] * o % mod; f[j+k]=(tmp1+tmp2)%mod,f[i+j+k]=(tmp1-tmp2+mod)%mod; } } } return; } inline void ntt2(long long* f, int n) { for (register int i = 0; i < n; ++i) if (i < rev[i]) swap(f[i], f[rev[i]]); for (register int i = 1; i < n; i <<= 1) { long long w = qpow(inv3, mod / (i << 1)); for (register int j = 0; j < n; j += i << 1) { long long o = 1; for (register int k = 0; k < i; ++k, o = o * w % mod) { long long tmp1 = f[j + k], tmp2 = f[i + j + k] * o % mod; f[j+k]=(tmp1+tmp2)%mod,f[i+j+k]=(tmp1-tmp2+mod)%mod; } } } long long _ = qpow(n); for (register int i = 0; i < n; ++i) f[i] = f[i] * _ % mod; return; } inline void ntt3(long long* f, long long* g, int n) { for (register int i=1;i<n;i+=2) rev[i-1]=rev[i]=rev[i>>1]>>1,rev[i]|=n>>1; for (register int i = 0; i < n; ++i) if (i < rev[i]) swap(f[i], f[rev[i]]), swap(g[i], g[rev[i]]); for (register int i = 1; i < n; i <<= 1) { long long w = qpow(3, mod / (i << 1)); for (register int j = 0; j < n; j += i << 1) { long long o = 1; for (register int k = 0; k < i; ++k, o = o * w % mod) { long long tmp1 = f[j + k], tmp2 = f[i + j + k] * o % mod; f[j+k]=(tmp1+tmp2)%mod,f[i+j+k]=(tmp1-tmp2+mod)%mod; tmp1 = g[j + k], tmp2 = g[i + j + k] * o % mod; g[j+k]=(tmp1+tmp2)%mod,g[i+j+k]=(tmp1-tmp2+mod)%mod; } } } return; } int deg[500005]; queue<int> q; inline void bfs(void) { while (!q.empty()) q.pop(); for (register int i = 1; i <= n; ++i) if ((deg[i] = ve[i].size()) == 1) q.push(i); while (!q.empty()) { for (register int i : ve[q.front()]) if (--deg[i] == 1) q.push(i); q.pop(); } int s = 1, cur, lst = 0; while (deg[s] < 2) ++s; cur = s; while (1) { cy[cy[500003]] = cur, ++cy[500003]; for (register int i : ve[cur]) { if (i ^ lst && deg[i] > 1) { lst = cur, cur = i; break; } } if (cur == s) break; } return; } int qu[500005], _h, _t; inline int grt(int s) { qu[_h = _t = 1] = s, dep[s] = fa[s] = 0; while (_h <= _t) { int cur = qu[_h]; sz[cur] = 1; for (register int i : ve[cur]) if (!vis[i] && i ^ fa[cur]) dep[i] = dep[cur] + 1, fa[qu[++_t] = i] = cur; ++_h; } for (register int i = _t; i; --i) { if (sz[qu[i]] >= _t + 1 >> 1) return qu[i]; sz[fa[qu[i]]] += sz[qu[i]]; } } inline void calc1(int root, int fact, int dis) { grt(root); for (register int i = 1; i <= _t; ++i) ++f[dep[qu[i]] += dis]; int len = 1; while (len <= _t << 1) len <<= 1; ntt1(f, len); for (register int i = 0; i < len; ++i) f[i] = f[i] * f[i] % mod; ntt2(f, len); for (register int i = 1; i <= _t; ++i) --f[dep[qu[i]] << 1]; fact = 1LL * fact * inv2 % mod; for (int i=0;i<=min(n,_t<<1);++i) ans[i+1]=(ans[i+1]+f[i]*fact%mod)%mod; for (register int i = 0; i < len; ++i) f[i] = 0; return; } void treedp(int root) { vis[root = grt(root)] = 1; calc1(root, 1, 0); for (register int i : ve[root]) if (!vis[i]) calc1(i, -1, 1); for (register int i : ve[root]) if (!vis[i]) treedp(i); return; } inline void gdp(void) { for (register int i = 1; i <= n; ++i) vis[i] = 0; for (register int i = 0; i < cy[500003]; ++i) { vis[cy[(i+cy[500003]-1)%cy[500003]]]=vis[cy[(i+1)%cy[500003]]]=1; vis[cy[i]] = 0, treedp(cy[i]); } return; } inline void mul(int len) { ntt3(a, b, len); for (register int i = 0; i < len; ++i) a[i] = a[i] * b[i] % mod; ntt2(a, len); return; } inline void calc2(int l, int r, int ql, int qr) { if (ql > qr) return; int n = 0, m = 0; for (register int i = l; i <= r; ++i) for (register int j = 0; j < cnt[i].size(); ++j) { n = max(n, r - i + 1 + j), a[r - i + 1 + j] += cnt[i][j]; } for (register int i = ql; i <= qr; ++i) for (register int j = 0; j < cnt[i].size(); ++j) { m = max(m, i - ql + 1 + j), b[i - ql + 1 + j] += cnt[i][j]; } int len = 1; while (len <= n + m) len <<= 1; mul(len); for (register int i=0;i<=n+m;++i) ans[i+ql-r-1]=(ans[i+ql-r-1]+a[i])%mod; for (register int i = 0; i < len; ++i) a[i] = b[i] = 0; return; } inline void calc3(int l, int r, int ql, int qr) { if (ql > qr) return; int n = 0, m = 0; for (register int i = l; i <= r; ++i) for (register int j = 0; j < cnt[i].size(); ++j) { n = max(n, i - l + 1 + j), a[i - l + 1 + j] += cnt[i][j]; } for (register int i = ql; i <= qr; ++i) for (register int j = 0; j < cnt[i].size(); ++j) { m = max(m, qr - i + 1 + j), b[qr - i + 1 + j] += cnt[i][j]; } int len = 1; while (len <= n + m) len <<= 1; mul(len); for (register int i = 0; i <= n + m; ++i) ans[i+l+cy[500003]-1-qr]=(ans[i+l+cy[500003]-1-qr]+a[i])%mod; for (register int i = 0; i < len; ++i) a[i] = b[i] = 0; return; } void ringdp(int l, int r) { if (l == r) return; int mid = l + r >> 1; calc2(mid + 1, r, min(cy[500003] - 1, max(r, l + cy[500003] / 2)) + 1, min(cy[500003] - 1, mid + cy[500003] / 2 + 1)); calc2(l, mid, mid + 1, min(r, l + cy[500003] / 2)); calc3(l,mid,mid+cy[500003]/2+1,min(cy[500003]-1,r+cy[500003]/2)); ringdp(l, mid), ringdp(mid + 1, r); return; } void sol(void) { for (register int i = 1; i <= n; ++i) vis[i] = 0; for (register int i = 0; i < cy[500003]; ++i) { vis[cy[(i+cy[500003]-1)%cy[500003]]]=vis[cy[(i+1)%cy[500003]]]=1; vis[cy[i]] = 0, grt(cy[i]), cnt[i].resize(_t); for (register int j = 0; j < _t; ++j) cnt[i][j] = 0; for (register int j = 1; j <= _t; ++j) ++cnt[i][dep[qu[j]]]; } ringdp(0, cy[500003] - 1); return; } int main() { int k, m, u, v; scanf("%d%d", &n, &k), m = n; for (register int i = 1; i <= n; ++i) ve[i].clear(); for (register int i = 1; i <= n; ++i) { scanf("%d%d", &u, &v); if (u ^ v) ve[u].push_back(v), ve[v].push_back(u); else --m; } if (m == n) { bfs(); for (register int i = 1; i <= n; ++i) ans[i] = 0; gdp(), sol(); } else treedp(1); for (register int i=1;i<=n;++i) ans[0]=(ans[0]+qpow(i,k)*ans[i+1]%mod)%mod; printf("%lld", (ans[0] * qpow(1LL * n * (n - 1) >> 1) % mod + mod) % mod); return 0; }