一、题目概述

给出一张图,求出 $ans=\sum_{i=1}^n\sum_{j=1}^n[i\ne j]\texttt d^k(i,j)$ 的值,其中 $\texttt d(x,y)$ 表示从 $x$ 到 $y$ 的最短路。
对于所有数据满足 $1\le n\le 10^5,1\le k\le 10^9$,保证给定的图 $G$ 满足题中要求,且不存在重边。
$\texttt{subtask 1:}~ 5\%$,满足 $n\le 1000$ 。
$\texttt{subtask 2:}~10\%$,满足 $k=1$ 。
$\texttt{subtask 3:}~15\%$,满足 $k=2$ 。
$\texttt{subtask 4:}~30\%$,满足 $G$ 中存在一条边 $(u,u)$ 。
$\texttt{subtask 5:}~40\%$,无额外限制。

二、解题思路

算法0

首先肯定可以 $\mathcal O(n^3)$ 跑 $\texttt{Floyd}$ ,但是这个是集训队互测的题,可能是由于集训队大佬们都不屑于写,于是就没有这一档部分分。

算法1

数据满足 $n\le 1000$ 。
考虑分类讨论:
如果这是一棵树的话,那么问题就是求 $ans=\sum_{i=1}^n\sum_{j=1}^n[i\ne j](\texttt{dep}_i+\texttt{dep}_j-2\texttt{dep}_{\texttt{LCA}(i,j)})^k$ ,直接枚举即可。
如果原图是一棵基环树的话,使用基环树的常规套路,先找出环,并且删去环上的任意一条边,剩下的边按照树的方案做一次。做完之后,我们单独考虑这条边 $(u,v)$ 对最短路的贡献,即有哪些路径可以使用 $s\longrightarrow u\rightarrow v\longrightarrow t$ 作为最短路。于是我们先从 $u$ 和 $v$ 向其它点跑一边最短路,然后 $\mathcal O(n^2)$ 枚举 $s$ 和 $t$ ,然后暴力更新即可。

算法2

数据满足 $k=1$ 。
考虑分类讨论:
如果原图是树,那么问题就是求 $ans=\sum_{i=1}^n\sum_{j=1}^n [i\ne j]\texttt{dep}_i+\texttt{dep}_j-2\texttt{dep}_{\texttt{LCA}(i,j)}$ ,直接枚举每个点作为路径的端点,和作为 $\texttt{LCA}$ 对答案的贡献即可。
如果原图是基环树,那么我们可以分析一下这张图的 `dfs` 树会长成什么样:
pic1
这就意味着,新加入的边如果对某一条路径有影响,那么这条路径的一定是一端在 $\texttt{v1}$ 的子树之外,另一端在 $\texttt{vn}$ 的子树之内。通过这个性质,我们就可以将基环树上的问题转化为树上的问题求解了。
这样的话,我们还可以顺便解决另外一个部分分:
数据满足 $k=2$ 。
答案变为 $ans=\sum_{i=1}^n\sum_{j=1}^n [i\ne j](\texttt{dep}_i+\texttt{dep}_j-2\texttt{dep}_{\texttt{LCA}(i,j)})^2$ ,将完全平方展开后利用上面的方法维护即可。

算法3

满足 $G$ 中存在一条边 $(u,u)$ ,即原图为树。
考虑到答案中存在乘方操作,不容易计算,于是令 $cnt_i$ 表示树上长度为 $i$ 的路径的个数,那么答案就是 $ans=\sum_{i=1}^ncnt_i\times i^k$。不难发现,$cnt$ 可以使用点分治维护,然后用多项式算法优化复杂度,于是这个部分分的问题就解决了。

算法4

数据满足 $1\le n\le 10^5,1\le k\le 10^9$,保证给定的图 $G$ 满足题中要求,且不存在重边。
如果原图是树的话,直接用上面的算法3就好了。
如果是基环树的话,考虑基环DP,先做环以外的点的树形DP,然后在换上合并即可。

参考代码:(代码格式化 Powered by Libre OJ)
#include <iostream>
#include <cstdio>
#include <queue>
#include <vector>
using namespace std;
const int mod = 998244353, inv2 = 499122177, inv3 = 332748118;
vector<int> ve[500005], cnt[500005];
int n, rev[500005], cy[500005], dep[500005], fa[500005], sz[500005], vis[500005];
long long f[500005], ans[500005], a[500005], b[500005];
long long qpow(long long a, int b = mod - 2) {
    long long rtv = 1;
    for (a %= mod; b; b >>= 1, a = a * a % mod)
        if (b & 1)
            rtv = rtv * a % mod;
    return rtv;
}
inline void ntt1(long long* f, int n) {
    for (register int i=1;i<n;i+=2) rev[i-1]=rev[i]=rev[i>>1]>>1,rev[i]|=n>>1;
    for (register int i = 0; i < n; ++i)
        if (i < rev[i])
            swap(f[i], f[rev[i]]);
    for (register int i = 1; i < n; i <<= 1) {
        long long w = qpow(3, mod / (i << 1));
        for (register int j = 0; j < n; j += i << 1) {
            long long o = 1;
            for (register int k = 0; k < i; ++k, o = o * w % mod) {
                long long tmp1 = f[j + k], tmp2 = f[i + j + k] * o % mod;
                f[j+k]=(tmp1+tmp2)%mod,f[i+j+k]=(tmp1-tmp2+mod)%mod;
            }
        }
    }
    return;
}
inline void ntt2(long long* f, int n) {
    for (register int i = 0; i < n; ++i)
        if (i < rev[i])
            swap(f[i], f[rev[i]]);
    for (register int i = 1; i < n; i <<= 1) {
        long long w = qpow(inv3, mod / (i << 1));
        for (register int j = 0; j < n; j += i << 1) {
            long long o = 1;
            for (register int k = 0; k < i; ++k, o = o * w % mod) {
                long long tmp1 = f[j + k], tmp2 = f[i + j + k] * o % mod;
                f[j+k]=(tmp1+tmp2)%mod,f[i+j+k]=(tmp1-tmp2+mod)%mod;
            }
        }
    }
    long long _ = qpow(n);
    for (register int i = 0; i < n; ++i) f[i] = f[i] * _ % mod;
    return;
}
inline void ntt3(long long* f, long long* g, int n) {
    for (register int i=1;i<n;i+=2) rev[i-1]=rev[i]=rev[i>>1]>>1,rev[i]|=n>>1;
    for (register int i = 0; i < n; ++i)
        if (i < rev[i])
            swap(f[i], f[rev[i]]), swap(g[i], g[rev[i]]);
    for (register int i = 1; i < n; i <<= 1) {
        long long w = qpow(3, mod / (i << 1));
        for (register int j = 0; j < n; j += i << 1) {
            long long o = 1;
            for (register int k = 0; k < i; ++k, o = o * w % mod) {
                long long tmp1 = f[j + k], tmp2 = f[i + j + k] * o % mod;
                f[j+k]=(tmp1+tmp2)%mod,f[i+j+k]=(tmp1-tmp2+mod)%mod;
                tmp1 = g[j + k], tmp2 = g[i + j + k] * o % mod;
                g[j+k]=(tmp1+tmp2)%mod,g[i+j+k]=(tmp1-tmp2+mod)%mod;
            }
        }
    }
    return;
}
int deg[500005];
queue<int> q;
inline void bfs(void) {
    while (!q.empty()) q.pop();
    for (register int i = 1; i <= n; ++i)
        if ((deg[i] = ve[i].size()) == 1)
            q.push(i);
    while (!q.empty()) {
        for (register int i : ve[q.front()])
            if (--deg[i] == 1)
                q.push(i);
        q.pop();
    }
    int s = 1, cur, lst = 0;
    while (deg[s] < 2) ++s;
    cur = s;
    while (1) {
        cy[cy[500003]] = cur, ++cy[500003];
        for (register int i : ve[cur]) {
            if (i ^ lst && deg[i] > 1) {
                lst = cur, cur = i;
                break;
            }
        }
        if (cur == s)
            break;
    }
    return;
}
int qu[500005], _h, _t;
inline int grt(int s) {
    qu[_h = _t = 1] = s, dep[s] = fa[s] = 0;
    while (_h <= _t) {
        int cur = qu[_h];
        sz[cur] = 1;
        for (register int i : ve[cur])
            if (!vis[i] && i ^ fa[cur])
                dep[i] = dep[cur] + 1, fa[qu[++_t] = i] = cur;
        ++_h;
    }
    for (register int i = _t; i; --i) {
        if (sz[qu[i]] >= _t + 1 >> 1)
            return qu[i];
        sz[fa[qu[i]]] += sz[qu[i]];
    }
}
inline void calc1(int root, int fact, int dis) {
    grt(root);
    for (register int i = 1; i <= _t; ++i) ++f[dep[qu[i]] += dis];
    int len = 1;
    while (len <= _t << 1) len <<= 1;
    ntt1(f, len);
    for (register int i = 0; i < len; ++i) f[i] = f[i] * f[i] % mod;
    ntt2(f, len);
    for (register int i = 1; i <= _t; ++i) --f[dep[qu[i]] << 1];
    fact = 1LL * fact * inv2 % mod;
    for (int i=0;i<=min(n,_t<<1);++i) ans[i+1]=(ans[i+1]+f[i]*fact%mod)%mod;
    for (register int i = 0; i < len; ++i) f[i] = 0;
    return;
}
void treedp(int root) {
    vis[root = grt(root)] = 1;
    calc1(root, 1, 0);
    for (register int i : ve[root])
        if (!vis[i])
            calc1(i, -1, 1);
    for (register int i : ve[root])
        if (!vis[i])
            treedp(i);
    return;
}
inline void gdp(void) {
    for (register int i = 1; i <= n; ++i) vis[i] = 0;
    for (register int i = 0; i < cy[500003]; ++i) {
        vis[cy[(i+cy[500003]-1)%cy[500003]]]=vis[cy[(i+1)%cy[500003]]]=1;
        vis[cy[i]] = 0, treedp(cy[i]);
    }
    return;
}
inline void mul(int len) {
    ntt3(a, b, len);
    for (register int i = 0; i < len; ++i) a[i] = a[i] * b[i] % mod;
    ntt2(a, len);
    return;
}
inline void calc2(int l, int r, int ql, int qr) {
    if (ql > qr)
        return;
    int n = 0, m = 0;
    for (register int i = l; i <= r; ++i)
        for (register int j = 0; j < cnt[i].size(); ++j) {
            n = max(n, r - i + 1 + j), a[r - i + 1 + j] += cnt[i][j];
        }
    for (register int i = ql; i <= qr; ++i)
        for (register int j = 0; j < cnt[i].size(); ++j) {
            m = max(m, i - ql + 1 + j), b[i - ql + 1 + j] += cnt[i][j];
        }
    int len = 1;
    while (len <= n + m) len <<= 1;
    mul(len);
    for (register int i=0;i<=n+m;++i) ans[i+ql-r-1]=(ans[i+ql-r-1]+a[i])%mod;
    for (register int i = 0; i < len; ++i) a[i] = b[i] = 0;
    return;
}
inline void calc3(int l, int r, int ql, int qr) {
    if (ql > qr)
        return;
    int n = 0, m = 0;
    for (register int i = l; i <= r; ++i)
        for (register int j = 0; j < cnt[i].size(); ++j) {
            n = max(n, i - l + 1 + j), a[i - l + 1 + j] += cnt[i][j];
        }
    for (register int i = ql; i <= qr; ++i)
        for (register int j = 0; j < cnt[i].size(); ++j) {
            m = max(m, qr - i + 1 + j), b[qr - i + 1 + j] += cnt[i][j];
        }
    int len = 1;
    while (len <= n + m) len <<= 1;
    mul(len);
    for (register int i = 0; i <= n + m; ++i)
        ans[i+l+cy[500003]-1-qr]=(ans[i+l+cy[500003]-1-qr]+a[i])%mod;
    for (register int i = 0; i < len; ++i) a[i] = b[i] = 0;
    return;
}
void ringdp(int l, int r) {
    if (l == r)
        return;
    int mid = l + r >> 1;
    calc2(mid + 1, r, min(cy[500003] - 1, max(r, l + cy[500003] / 2)) + 1,
          min(cy[500003] - 1, mid + cy[500003] / 2 + 1));
    calc2(l, mid, mid + 1, min(r, l + cy[500003] / 2));
    calc3(l,mid,mid+cy[500003]/2+1,min(cy[500003]-1,r+cy[500003]/2));
    ringdp(l, mid), ringdp(mid + 1, r);
    return;
}
void sol(void) {
    for (register int i = 1; i <= n; ++i) vis[i] = 0;
    for (register int i = 0; i < cy[500003]; ++i) {
        vis[cy[(i+cy[500003]-1)%cy[500003]]]=vis[cy[(i+1)%cy[500003]]]=1;
        vis[cy[i]] = 0, grt(cy[i]), cnt[i].resize(_t);
        for (register int j = 0; j < _t; ++j) cnt[i][j] = 0;
        for (register int j = 1; j <= _t; ++j) ++cnt[i][dep[qu[j]]];
    }
    ringdp(0, cy[500003] - 1);
    return;
}
int main() {
    int k, m, u, v;
    scanf("%d%d", &n, &k), m = n;
    for (register int i = 1; i <= n; ++i) ve[i].clear();
    for (register int i = 1; i <= n; ++i) {
        scanf("%d%d", &u, &v);
        if (u ^ v)
            ve[u].push_back(v), ve[v].push_back(u);
        else
            --m;
    }
    if (m == n) {
        bfs();
        for (register int i = 1; i <= n; ++i) ans[i] = 0;
        gdp(), sol();
    } else
        treedp(1);
    for (register int i=1;i<=n;++i) ans[0]=(ans[0]+qpow(i,k)*ans[i+1]%mod)%mod;
    printf("%lld", (ans[0] * qpow(1LL * n * (n - 1) >> 1) % mod + mod) % mod);
    return 0;
}