题目链接

简单环

题目描述

给定一张 个顶点和 条边的无向图。一个简单环(Simple Cycle)是指顶点序列 () 满足:

  1. 均在图中存在(下标对 取模)。
  2. 顶点 均不重复出现。

现给定正整数 ,请统计图中长度大于 2(即长度 )的简单环数量,并按环长度对 取模后的余数分类,统计出各个长度类的简单环总数对 取模的结果。

解题思路

这是一个在图中寻找并计数特定长度简单环的问题。由于顶点数 非常小(),这强烈暗示了需要使用一个指数级复杂度的算法,状态压缩动态规划 (DP on Subsets) 是解决此类问题的经典方法。

为了避免对同一个环重复计数(例如,从不同节点开始或沿不同方向遍历),我们采用以下策略:

  1. 固定起点:对于任何一个环,我们只在环中编号最小的顶点处统计它。
  2. 固定方向:我们只计算从起点出发,经过的中间节点编号都严格大于起点的路径。这样可以避免 s -> u -> v -> ss -> v -> u -> s 被重复计算。最后将结果除以 2 即可得到无向环的数量。

算法步骤

  1. DP 状态定义: 我们定义 dp[mask][u] 为:从 mask 中编号最小的顶点 s (即 lowbit(mask)) 出发,经过 mask 中所有顶点,最终到达点 u 的简单路径的数量。

  2. DP 初始化: 对于每个顶点 i (从 0n-1),它自身构成一个只包含一个点的路径集合。所以,我们初始化 dp[1 << i][i] = 1

  3. DP 转移: 我们按 mask 从小到大的顺序进行迭代。对于一个给定的 mask 和其中的终点 u,如果 dp[mask][u] > 0,我们可以尝试从 u 走向一个新顶点 v 来扩展路径。

    • s = lowbit_idx(mask): 找到当前路径的起点(即 mask 中最小的顶点)。
    • 遍历 u 的所有邻居 v
    • 如果 v mask 中,并且 v > s(这是为了保证所有中间节点都大于起点,避免重复),我们就可以扩展路径: dp[mask | (1 << v)][v] += dp[mask][u]
  4. 统计环: 在填充 DP 表的过程中,我们就可以统计环的数量。一个从 s 出发,经过 mask,到达 u 的路径,如果 us 之间存在一条边,就构成了一个环。

    • 遍历所有 masku
    • 如果 dp[mask][u] > 0,并且 us = lowbit_idx(mask) 之间有边,那么我们就找到了 dp[mask][u] 个环。
    • 这个环的长度是 len = popcount(mask)
    • 如果 len >= 3,我们将 dp[mask][u] 的值累加到对应长度的环计数器 cycle_counts[len] 上。
  5. 处理结果:

    • 完成 DP 后,cycle_counts[len] 存储了所有长度为 len有向环的数量(且起点是环中最小节点)。为了得到无向环的数量,我们需要将每个计数除以 2。在模意义下,除以 2 等于乘以 2 的模逆元。
    • 创建一个大小为 k 的答案数组 ans
    • 遍历所有可能的环长度 len,从 3n
    • 将无向环的数量 (cycle_counts[len] * inv2) % MOD 累加到最终答案数组 ans[len % k] 中。
    • 最后按顺序输出 ans 数组中的 k 个结果。

这种 的 DP 方法对于 是可以在时限内通过的。

代码

#include <iostream>
#include <vector>
#include <numeric>

using namespace std;

const int MOD = 998244353;

long long power(long long base, long long exp) {
    long long res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}

long long modInverse(long long n) {
    return power(n, MOD - 2);
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m, k;
    cin >> n >> m >> k;

    vector<vector<bool>> g(n, vector<bool>(n, false));
    for (int i = 0; i < m; ++i) {
        int u, v;
        cin >> u >> v;
        --u; --v;
        g[u][v] = g[v][u] = true;
    }

    vector<vector<long long>> dp(1 << n, vector<long long>(n, 0));
    for (int i = 0; i < n; ++i) {
        dp[1 << i][i] = 1;
    }

    vector<long long> cycle_counts(n + 1, 0);

    for (int mask = 1; mask < (1 << n); ++mask) {
        int s = __builtin_ctz(mask);
        for (int u = s; u < n; ++u) {
            if ((mask >> u) & 1) { // u is in mask
                if (dp[mask][u] > 0) {
                    for (int v = s + 1; v < n; ++v) {
                        if (!((mask >> v) & 1) && g[u][v]) { // v is not in mask and edge exists
                            int next_mask = mask | (1 << v);
                            dp[next_mask][v] = (dp[next_mask][v] + dp[mask][u]) % MOD;
                        }
                    }
                    int len = __builtin_popcount(mask);
                    if (len >= 3 && g[u][s]) {
                        cycle_counts[len] = (cycle_counts[len] + dp[mask][u]) % MOD;
                    }
                }
            }
        }
    }

    vector<long long> ans(k, 0);
    long long inv2 = modInverse(2);

    for (int len = 3; len <= n; ++len) {
        if (cycle_counts[len] > 0) {
            long long undirected_count = (cycle_counts[len] * inv2) % MOD;
            ans[len % k] = (ans[len % k] + undirected_count) % MOD;
        }
    }

    for (int i = 0; i < k; ++i) {
        cout << ans[i] << "\n";
    }

    return 0;
}
import java.io.*;
import java.util.*;

public class Main {
    static final int MOD = 998244353;

    public static long power(long base, long exp) {
        long res = 1;
        base %= MOD;
        while (exp > 0) {
            if (exp % 2 == 1) res = (res * base) % MOD;
            base = (base * base) % MOD;
            exp /= 2;
        }
        return res;
    }

    public static long modInverse(long n) {
        return power(n, MOD - 2);
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        int n = Integer.parseInt(st.nextToken());
        int m = Integer.parseInt(st.nextToken());
        int k = Integer.parseInt(st.nextToken());

        boolean[][] g = new boolean[n][n];
        for (int i = 0; i < m; i++) {
            st = new StringTokenizer(br.readLine());
            int u = Integer.parseInt(st.nextToken()) - 1;
            int v = Integer.parseInt(st.nextToken()) - 1;
            g[u][v] = g[v][u] = true;
        }

        long[][] dp = new long[1 << n][n];
        for (int i = 0; i < n; i++) {
            dp[1 << i][i] = 1;
        }

        long[] cycleCounts = new long[n + 1];

        for (int mask = 1; mask < (1 << n); mask++) {
            int s = Integer.numberOfTrailingZeros(mask);
            for (int u = s; u < n; u++) {
                if (((mask >> u) & 1) == 1) { // u is in mask
                    if (dp[mask][u] > 0) {
                        for (int v = s + 1; v < n; v++) {
                            if (((mask >> v) & 1) == 0 && g[u][v]) { // v not in mask and edge exists
                                int nextMask = mask | (1 << v);
                                dp[nextMask][v] = (dp[nextMask][v] + dp[mask][u]) % MOD;
                            }
                        }
                        int len = Integer.bitCount(mask);
                        if (len >= 3 && g[u][s]) {
                            cycleCounts[len] = (cycleCounts[len] + dp[mask][u]) % MOD;
                        }
                    }
                }
            }
        }

        long[] ans = new long[k];
        long inv2 = modInverse(2);

        for (int len = 3; len <= n; len++) {
            if (cycleCounts[len] > 0) {
                long undirectedCount = (cycleCounts[len] * inv2) % MOD;
                ans[len % k] = (ans[len % k] + undirectedCount) % MOD;
            }
        }
        
        PrintWriter out = new PrintWriter(System.out);
        for (int i = 0; i < k; i++) {
            out.println(ans[i]);
        }
        out.flush();
    }
}
import sys

MOD = 998244353

def power(base, exp):
    res = 1
    base %= MOD
    while exp > 0:
        if exp % 2 == 1:
            res = (res * base) % MOD
        base = (base * base) % MOD
        exp //= 2
    return res

def mod_inverse(n):
    return power(n, MOD - 2)

def main():
    try:
        input = sys.stdin.readline
        n, m, k = map(int, input().split())
        
        g = [[False] * n for _ in range(n)]
        for _ in range(m):
            u, v = map(int, input().split())
            u -= 1
            v -= 1
            g[u][v] = g[v][u] = True

        dp = [[0] * n for _ in range(1 << n)]
        for i in range(n):
            dp[1 << i][i] = 1
        
        cycle_counts = [0] * (n + 1)
        
        for mask in range(1, 1 << n):
            # Find the lowest set bit (the starting node s)
            s = (mask & -mask).bit_length() - 1
            for u in range(s, n):
                if (mask >> u) & 1:
                    if dp[mask][u] > 0:
                        # Extend path to a new node v > s
                        for v in range(s + 1, n):
                            if not ((mask >> v) & 1) and g[u][v]:
                                next_mask = mask | (1 << v)
                                dp[next_mask][v] = (dp[next_mask][v] + dp[mask][u]) % MOD
                        
                        # Check for a cycle closing back to s
                        length = bin(mask).count('1')
                        if length >= 3 and g[u][s]:
                            cycle_counts[length] = (cycle_counts[length] + dp[mask][u]) % MOD
        
        ans = [0] * k
        inv2 = mod_inverse(2)
        
        for length in range(3, n + 1):
            if cycle_counts[length] > 0:
                undirected_count = (cycle_counts[length] * inv2) % MOD
                ans[length % k] = (ans[length % k] + undirected_count) % MOD

        for i in range(k):
            sys.stdout.write(str(ans[i]) + '\n')

    except (IOError, ValueError):
        return

main()

算法及复杂度

  • 算法:状态压缩动态规划 (DP on Subsets)
  • 时间复杂度:。DP 状态转移是主要耗时部分,有 mask,每个 mask 下最多遍历 uv
  • 空间复杂度:,用于存储 DP 表。