C Forest

题意:给定 nn 个点 mm 条带权边的无向图,问从中选出若干条边和全部的点构成的 2m2^m 张子图中,最小生成森林的边权值和。n16n \leq 16m100m \leq 100

解法:首先从小到大的对边进行排序,等边权的按照边编号排序,保证生成森林唯一。

从小到大依次加入每条边考虑贡献。考虑第 ii 条边 (u,v)(u,v) 在哪些子图中会加入到生成森林中:u,vu,v 不连通。考虑补集,设比它大的还有 ll 条边,则答案为 fi1,SgU/S2lf_{i-1,S}g_{U/S}2^l,其中 fi1,Sf_{i-1,S} 表示前 i1i-1 条边构成的所有 2i12^{i-1} 个子图中,u,vu,v 所在连通块为 SS 的子图个数,UU 为全集,gSg_{S} 表示 SS 全部的导出子图个数,统计 SS 中边数目为 xx,则 gS=2xg_S=2^x

考虑如何计算 fi,Sf_{i,S}。首先加入 (u,v)(u,v) 只会对 {u,v}S\{u,v\} \sub SSS 有贡献,有:

fi,S2fi1,S+TSfi1,Tfi1,S/Tf_{i,S} \leftarrow 2f_{i-1,S}+\sum_{T \sub S} f_{i-1,T}f_{i-1,S / T}

其含义为:对于已经连通的 fi1,Sf_{i-1,S},这条边是否出现在子图中是无所谓的;然后枚举构成 SS 的两个子连通块 T,S/TT,S/T,要求 uTu \in TvS/Tv \in S/T,利用这条边进行连通。这一过程可以通过子集卷积进行加速到 O(n22n)\mathcal O(n^22^n),朴素实现为 O(3n)\mathcal O(3^n)

因而总的复杂度为 O(m3n)\mathcal O(m3^n)

#include <bits/stdc++.h>
using namespace std;
const int N = 1 << 16;
const long long mod = 998244353;
long long g[N], f[105][N];
long long th[105];
struct line
{
    int from;
    int to;
    int w;
    bool operator<(const line &b)const
    {
        return w < b.w;
    }
    line(int _from, int _to, int _w)
    {
        from = _from;
        to = _to;
        w = _w;
    }
};
int main()
{
    th[0] = 1;
    for (int i = 1; i <= 100;i++)
        th[i] = th[i - 1] * 2 % mod;
    int n, x;
    scanf("%d", &n);
    int all = (1 << n) - 1;
    vector<line> que;
    for (int i = 0; i < n;i++)
        for (int j = 0; j < n;j++)
        {
            scanf("%d", &x);
            if (x && i < j)
                que.emplace_back(i, j, x);
        }
    sort(que.begin(), que.end());
    for (int i = 0; i < n;i++)
        f[0][1 << i] = 1;
    for (int i = 0; i < 1 << n;i++)
        g[i] = 1;
    long long ans = 0;
    for (int i = 0; i < que.size(); i++)
    {
        long long now = th[i];
        int u = que[i].from, v = que[i].to, w = que[i].w;
        for (int S = 0; S < 1 << n; S++)
            if ((S >> u & 1) && (S >> v & 1))
                now = (now - f[i][S] * g[all ^ S] % mod + mod) % mod;
        ans = (ans + now * w % mod * th[que.size() - i - 1] % mod) % mod;
        for (int S = 0; S < 1 << n;S++)
        {
            f[i + 1][S] = f[i][S];
            if (!(S >> u & 1) || !(S >> v & 1))
                continue;
            f[i + 1][S] = f[i + 1][S] * 2 % mod;
            g[S] = g[S] * 2 % mod;
            for (int T = (S - 1) & S; T > S - T; T = (T - 1) & S)
                if ((T >> u & 1) ^ (T >> v & 1))
                    f[i + 1][S] = (f[i + 1][S] + f[i][T] * f[i][S ^ T] % mod) % mod;
        }
    }
    printf("%lld", ans);
    return 0;
}