直接求不好求,我们考虑 \(min-max\) 容斥:\(\displaystyle E(max(S))=\sum_{T \subseteq S}(-1)^{|T|+1}E(min(T))\)
其中 \(S\) 为到达相应的点花费时间的集合, \(max(S)\) 为到过所有点的时间, \(min(S)\) 为到过一个点的时间。
然后就变成了给定一个集合 \(S\) ,求 \(min(S)\) .
我们考虑 \(DP\) ,设 \(f[i]\) 为从 \(i\) 点开始,到过 \(S\) 中一个点的期望时间。
若 \(i \in S\) ,则 \(f[i]=0\)
否则 \(\displaystyle f[i]=\frac{f[Fa[i]]+\sum f[son]}{du[i]}+1\)
这时我们可以暴力高斯消元了,但这里有个小技巧:树上路径期望问题可以把每个节点的 \(dp\) 值表示 \(a \times f[Fa[i]] + b\)的形式
然后就可以化简一下式子。
\(\displaystyle f[i]=\frac{f[Fa[i]]+\sum f[son]}{du[i]}+1\)
\(\displaystyle f[i]=\frac{f[Fa[i]]+suma \times f[i]+sumb}{du[i]}+1\)
其中 \(\displaystyle suma=\sum a[son] \space\space\space\space\space\space sumb=\sum b[son]\)
\(\displaystyle (du[i]-suma)f[i]=f[Fa[i]]+sumb+du[i]\)
\(\displaystyle f[i]=\frac{1}{du[i]-suma}f[Fa[i]]+\frac{sumb+du[i]}{du[i]-suma}\)
对比 \(a \times f[Fa[i]] + b\)
可得
\(\displaystyle a[i]=\frac{1}{du[i]-suma}\)和\(\displaystyle b[i]=\frac{sumb+du[i]}{du[i]-suma}\)
我们就求得了\(E(min(T))\),用 \(FWT\) 快速求子集和即可
#include<iostream>
#include<cstdio>
#define LL long long
using namespace std;
int n, q, root, all, x, y, tot, k, s;
const int N = 19, mod = 998244353;
int head[N], to[N << 1], nt[N << 1], du[N], A[N], B[N], f[1 << 18 | 1];
void add(int f, int t)
{
to[++tot] = t; nt[tot] = head[f]; head[f] = tot;
}
LL ksm(LL a, LL b, LL mod)
{
LL res = 1; a %= mod;
for (; b; b >>= 1, a = a * a % mod)
if (b & 1)res = res * a % mod;
return res;
}
void dfs(int x, int fa, int s)
{
if (s & (1 << (x - 1)))return;
int sumA = 0, sumB = 0;
for (int i = head[x]; i; i = nt[i])
if (to[i] != fa)
{
dfs(to[i], x, s);
(sumA += A[to[i]]) %= mod;
(sumB += B[to[i]]) %= mod;
}
int inv = ksm(du[x] - sumA, mod - 2, mod);
A[x] = inv; B[x] = (LL)inv * (sumB + du[x]) % mod;
}
int pan(int x)
{
int res = 0;
while (x)res += (x & 1), x >>= 1;
return res & 1 ? 1 : -1;
}
int main()
{
cin >> n >> q >> root; all = 1 << n;
for (int i = 1; i < n; ++i)
{
scanf("%d%d", &x, &y);
add(x, y); add(y, x); ++du[x]; ++du[y];
}
for (int s = 1; s < all; ++s)
{
for (int i = 1; i <= n; ++i)A[i] = B[i] = 0;
dfs(root, 0, s);
f[s] = (pan(s) * B[root] + mod) % mod;
}
for (int mid = 1; mid < all; mid <<= 1)
for (int j = 0, len = mid << 1; j < all; j += len)
for (int k = j; k < j + mid; ++k)
(f[k + mid] += f[k]) %= mod;
while (q--)
{
scanf("%d", &k); s = 0;
for (int i = 1; i <= k; ++i)scanf("%d", &x), s |= 1 << (x - 1);
printf("%d\n", f[s]);
}
}