题目传送门

算法:min-max 容斥、树上背包、NTT。

题意简述

有一棵 个点的树。一开始所有点都是白色,每次操作会随机选择 条路径中的一条,将路径上所有点染黑。求所有点都被染黑的期望操作数。

。多组数据。对 取模。

题解

套路性地,我们使用 min-max 容斥。

如果我们把树画出来,并将 内的点标记,可以发现,这些点把原树分成了若干个连通块,而每个连通块内部可以任意选取路径,能够保证该路径不经过 内的点。而一旦路径跨越连通块,那么一定经过 内的点。

记共分成了 个连通块 ,则不经过 的路径数为

那么概率为

继续化简原式,得

对于最终情况,一个划分方案给答案带来的贡献仅仅是 两块。为了获得这些信息,我们只关心于选点个数为奇数、偶数的方案中,有多少个方案的 为我枚举的常数

考虑设计一个动态规划来解决这个问题。我们先选取 号点作为全树的根,把该树变成一棵有根树。

表示只考虑了以 为根的子树,选了奇数/偶数个点,包含 的连通块大小为 (不含 所在连通块)的值为 的方案数。其中 表示 号点不在连通块中(即在所选的点集中)。

此外,我们记一个数组 表示以 为根的子树,选了奇数/偶数个点,(含 所在连通块)的值为 的方案数。显然有
$$

转移类似树上背包。合并 与一个儿子 时,有转移式:

暴力转移复杂度是 的,并不可接受。

我们留意到 一维的大小是不超过 的子树大小的,这是树上背包的经典形式。因此复杂度就少了一个 ,变成了

继续观察,发现 一维的转移是一个卷积形式,可以使用 NTT 进行优化。换句话说,除了把 算贡献的地方,如果我们把状态的最后一维看成一个 次的多项式,那么这里的所有运算都可以看成多项式加法和多项式乘法。因此我们可以在一开始就用点值表示法表示 的值,只有在 算贡献的时候,我们才进行一次 INTT,把点值表示法还原回系数表示法,得到 后,就又可以还原成点值表示法。

于是原本 的转移被优化到了 。这样以后总复杂度变为

注意常数优化应该是可以通过的。

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>

const int MaxN = 50 + 1;
const int MaxV = 2048 + 5;
const int Mod = 998244353, Prt = 3;

struct Graph {
  int cnte;
  int Head[MaxN], To[MaxN * 2], Next[MaxN * 2];

  inline void clear() {
    cnte = 0;
    memset(Head, 0, sizeof Head);
    memset(To, 0, sizeof To);
    memset(Next, 0, sizeof Next);
  }

  inline void addEdge(int from, int to) {
    cnte++; To[cnte] = to;
    Next[cnte] = Head[from]; Head[from] = cnte;
  }
};

int Te, N;
int Fa[MaxN], Siz[MaxN];
int F[MaxN][2][MaxN][MaxV], G[MaxN][2][MaxV];
int Rev[13][MaxV], W[2][MaxV], Inv[MaxV], Log[MaxV];
Graph Gr;
inline int add(int x, int y) { return (x += y) >= Mod ? x - Mod : x; }
inline int sub(int x, int y) { return (x -= y) < 0 ? x + Mod : x; }
inline int mul(int x, int y) { return 1LL * x * y % Mod; }
inline int pw(int x, int y) { int z = 1; for (; y; y >>= 1, x = mul(x, x)) if (y & 1) z = mul(z, x); return z; }
inline int inv(int x) { return pw(x, Mod - 2); }
inline int sep(int x, int y) { return mul(x, inv(y)); }
inline void inc(int &x, int y = 1) { x = add(x, y); }
inline void dec(int &x, int y = 1) { x = sub(x, y); }

void init() {
  scanf("%d", &N);
  for (int i = 1; i < N; ++i) {
    int u, v;
    scanf("%d %d", &u, &v);
    Gr.addEdge(u, v);
    Gr.addEdge(v, u);
  }
}

void dfs1(int u) {
  Siz[u] = 1;
  for (int i = Gr.Head[u]; i; i = Gr.Next[i]) {
    int v = Gr.To[i];
    if (v == Fa[u]) continue;
    Fa[v] = u;
    dfs1(v);
    Siz[u] += Siz[v];
  }
}

inline void ntt(int *a, int n, int f) {
  for (int i = 1; i < n; ++i)
    if (i < Rev[Log[n]][i]) std::swap(a[i], a[Rev[Log[n]][i]]);
  for (int i = 1; i < n; i <<= 1) {
    int w = W[f][i];
    for (int j = 0; j < n; j += (i << 1)) {
      int x = 1;
      for (int k = 0; k < i; ++k, x = mul(x, w)) {
        int lson = a[j + k], rson = a[i + j + k];
        a[j + k] = add(lson, mul(rson, x));
        a[i + j + k] = sub(lson, mul(rson, x));
      }
    }
  }
  if (f == 1)
    for (int i = 0; i < n; ++i) a[i] = mul(a[i], Inv[n]);
}

inline int getPow2(int n) {
  int v = 1;
  while (v < n) v <<= 1;
  return v;
}

inline void calcG(int u) {
  for (int odd = 0; odd <= 1; ++odd) {
    for (int j = 0; j <= Siz[u]; ++j)
      for (int k = 0; k <= Siz[u] * (Siz[u] + 1) / 2; ++k) {
        int newK = j * (j + 1) / 2 + k;
        if (newK > N * (N + 1) / 2) break;
        inc(G[u][odd][newK], F[u][odd][j][k]);
      }
  }
}

void dfs2(int u) {
  F[u][1][0][0] = F[u][0][1][0] = 1;
  int sz = 1;
  std::vector<int> vec;
  for (int i = Gr.Head[u]; i; i = Gr.Next[i]) {
    int v = Gr.To[i];
    if (v == Fa[u]) continue;
    dfs2(v);
    vec.push_back(v);
  }
  std::sort(vec.begin(), vec.end(), [](int x, int y){ return Siz[x] < Siz[y]; });
  for (int v : vec) {
    sz += Siz[v];
    static int f[2][MaxN][MaxV];
    int len = getPow2((sz - 1) * sz / 2 + 1);
    for (int j = 0; j <= sz; ++j)
      for (int k = 0; k < len; ++k)
        f[0][j][k]  = f[1][j][k] = 0;
    ntt(G[v][0], len, 0); ntt(G[v][1], len, 0);
    for (int j = 0; j <= sz - Siz[v]; ++j) ntt(F[u][0][j], len, 0), ntt(F[u][1][j], len, 0);
    for (int j = 0; j <= Siz[v]; ++j) ntt(F[v][0][j], len, 0), ntt(F[v][1][j], len, 0);

    for (int j = 0; j <= sz; ++j) {
      for (int odd2 = 0; odd2 <= 1; ++odd2) {
        if (j == 0) {
          for (int k = 0; k < len; ++k)
            inc(f[0][j][k], mul(F[u][odd2][j][k], G[v][odd2][k]));
        } else {
          for (int j2 = std::max(0, j - sz + Siz[v]); j2 <= Siz[v] && j2 < j; ++j2) {
            for (int k = 0; k < len; ++k)
              inc(f[0][j][k], mul(F[u][odd2][j - j2][k], F[v][odd2][j2][k]));
          }
        }
      }
    }
    for (int j = 0; j <= sz; ++j) {
      for (int odd2 = 0; odd2 <= 1; ++odd2) {
        if (j == 0) {
          for (int k = 0; k < len; ++k)
            inc(f[1][j][k], mul(F[u][1 ^ odd2][j][k], G[v][odd2][k]));
        } else {
          for (int j2 = std::max(0, j - sz + Siz[v]); j2 <= Siz[v] && j2 < j; ++j2) {
            for (int k = 0; k < len; ++k)
              inc(f[1][j][k], mul(F[u][1 ^ odd2][j - j2][k], F[v][odd2][j2][k]));
          }
        }
      }
    }

    for (int j = 0; j <= sz; ++j) {
      ntt(f[0][j], len, 1);
      for (int k = 0; k < len; ++k)
        F[u][0][j][k] = f[0][j][k];
    }
    for (int j = 0; j <= sz; ++j) {
      ntt(f[1][j], len, 1);
      for (int k = 0; k < len; ++k)
        F[u][1][j][k] = f[1][j][k];
    }
  }
  calcG(u);
}

void solve() {
  dfs1(1);
  dfs2(1);
  int ans = 0;
  for (int odd = 0; odd <= 1; ++odd)
    for (int k = 0; k < N * (N + 1) / 2; ++k) {
      if (odd == 1) inc(ans, mul(G[1][odd][k], inv(N * (N + 1) / 2 - k)));
      else dec(ans, mul(G[1][odd][k], inv(N * (N + 1) / 2 - k)));
    }
  ans = mul(ans, N * (N + 1) / 2);
  printf("%d\n", ans);
}

void clear() {
  memset(Fa, 0, sizeof Fa);
  memset(Siz, 0, sizeof Siz);
  memset(F, 0, sizeof F);
  memset(G, 0, sizeof G);
  Gr.clear();
}

int main() {
  for (int i = 1, l = 0; i <= 2048; i <<= 1, ++l) {
    Log[i] = l;
    Rev[l][0] = 0;
    for (int j = 1; j < i; ++j) {
      Rev[l][j] = (Rev[l][j >> 1]) >> 1;
      if (j & 1) Rev[l][j] |= (i >> 1);
    }
    Inv[i] = inv(i);
  }
  for (int f = 0; f <= 1; ++f)
    for (int i = 1; i < 2048; i <<= 1)
      W[f][i] = pw(pw(Prt, f == 0 ? 1 : Mod - 2), (Mod - 1) / (i << 1));
  scanf("%d", &Te);
  for (int t = 1; t <= Te; ++t) {
    init();
    printf("Case #%d: ", t);
    solve();
    clear();
  }
  return 0;
}