题目链接
题目描述
给定一张 个顶点和
条边的无向图。一个简单环(Simple Cycle)是指顶点序列
(
) 满足:
- 边
均在图中存在(下标对
取模)。
- 顶点
均不重复出现。
现给定正整数 ,请统计图中长度大于 2(即长度
)的简单环数量,并按环长度对
取模后的余数分类,统计出各个长度类的简单环总数对
取模的结果。
解题思路
这是一个在图中寻找并计数特定长度简单环的问题。由于顶点数 非常小(
),这强烈暗示了需要使用一个指数级复杂度的算法,状态压缩动态规划 (DP on Subsets) 是解决此类问题的经典方法。
为了避免对同一个环重复计数(例如,从不同节点开始或沿不同方向遍历),我们采用以下策略:
- 固定起点:对于任何一个环,我们只在环中编号最小的顶点处统计它。
- 固定方向:我们只计算从起点出发,经过的中间节点编号都严格大于起点的路径。这样可以避免
s -> u -> v -> s
和s -> v -> u -> s
被重复计算。最后将结果除以 2 即可得到无向环的数量。
算法步骤
-
DP 状态定义: 我们定义
dp[mask][u]
为:从mask
中编号最小的顶点s
(即lowbit(mask)
) 出发,经过mask
中所有顶点,最终到达点u
的简单路径的数量。 -
DP 初始化: 对于每个顶点
i
(从0
到n-1
),它自身构成一个只包含一个点的路径集合。所以,我们初始化dp[1 << i][i] = 1
。 -
DP 转移: 我们按
mask
从小到大的顺序进行迭代。对于一个给定的mask
和其中的终点u
,如果dp[mask][u] > 0
,我们可以尝试从u
走向一个新顶点v
来扩展路径。s = lowbit_idx(mask)
: 找到当前路径的起点(即mask
中最小的顶点)。- 遍历
u
的所有邻居v
。 - 如果
v
不在mask
中,并且v > s
(这是为了保证所有中间节点都大于起点,避免重复),我们就可以扩展路径:dp[mask | (1 << v)][v] += dp[mask][u]
-
统计环: 在填充 DP 表的过程中,我们就可以统计环的数量。一个从
s
出发,经过mask
,到达u
的路径,如果u
和s
之间存在一条边,就构成了一个环。- 遍历所有
mask
和u
。 - 如果
dp[mask][u] > 0
,并且u
和s = lowbit_idx(mask)
之间有边,那么我们就找到了dp[mask][u]
个环。 - 这个环的长度是
len = popcount(mask)
。 - 如果
len >= 3
,我们将dp[mask][u]
的值累加到对应长度的环计数器cycle_counts[len]
上。
- 遍历所有
-
处理结果:
- 完成 DP 后,
cycle_counts[len]
存储了所有长度为len
的有向环的数量(且起点是环中最小节点)。为了得到无向环的数量,我们需要将每个计数除以 2。在模意义下,除以 2 等于乘以 2 的模逆元。 - 创建一个大小为
k
的答案数组ans
。 - 遍历所有可能的环长度
len
,从 3 到n
。 - 将无向环的数量
(cycle_counts[len] * inv2) % MOD
累加到最终答案数组ans[len % k]
中。 - 最后按顺序输出
ans
数组中的k
个结果。
- 完成 DP 后,
这种 的 DP 方法对于
是可以在时限内通过的。
代码
#include <iostream>
#include <vector>
#include <numeric>
using namespace std;
const int MOD = 998244353;
long long power(long long base, long long exp) {
long long res = 1;
base %= MOD;
while (exp > 0) {
if (exp % 2 == 1) res = (res * base) % MOD;
base = (base * base) % MOD;
exp /= 2;
}
return res;
}
long long modInverse(long long n) {
return power(n, MOD - 2);
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n, m, k;
cin >> n >> m >> k;
vector<vector<bool>> g(n, vector<bool>(n, false));
for (int i = 0; i < m; ++i) {
int u, v;
cin >> u >> v;
--u; --v;
g[u][v] = g[v][u] = true;
}
vector<vector<long long>> dp(1 << n, vector<long long>(n, 0));
for (int i = 0; i < n; ++i) {
dp[1 << i][i] = 1;
}
vector<long long> cycle_counts(n + 1, 0);
for (int mask = 1; mask < (1 << n); ++mask) {
int s = __builtin_ctz(mask);
for (int u = s; u < n; ++u) {
if ((mask >> u) & 1) { // u is in mask
if (dp[mask][u] > 0) {
for (int v = s + 1; v < n; ++v) {
if (!((mask >> v) & 1) && g[u][v]) { // v is not in mask and edge exists
int next_mask = mask | (1 << v);
dp[next_mask][v] = (dp[next_mask][v] + dp[mask][u]) % MOD;
}
}
int len = __builtin_popcount(mask);
if (len >= 3 && g[u][s]) {
cycle_counts[len] = (cycle_counts[len] + dp[mask][u]) % MOD;
}
}
}
}
}
vector<long long> ans(k, 0);
long long inv2 = modInverse(2);
for (int len = 3; len <= n; ++len) {
if (cycle_counts[len] > 0) {
long long undirected_count = (cycle_counts[len] * inv2) % MOD;
ans[len % k] = (ans[len % k] + undirected_count) % MOD;
}
}
for (int i = 0; i < k; ++i) {
cout << ans[i] << "\n";
}
return 0;
}
import java.io.*;
import java.util.*;
public class Main {
static final int MOD = 998244353;
public static long power(long base, long exp) {
long res = 1;
base %= MOD;
while (exp > 0) {
if (exp % 2 == 1) res = (res * base) % MOD;
base = (base * base) % MOD;
exp /= 2;
}
return res;
}
public static long modInverse(long n) {
return power(n, MOD - 2);
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
int n = Integer.parseInt(st.nextToken());
int m = Integer.parseInt(st.nextToken());
int k = Integer.parseInt(st.nextToken());
boolean[][] g = new boolean[n][n];
for (int i = 0; i < m; i++) {
st = new StringTokenizer(br.readLine());
int u = Integer.parseInt(st.nextToken()) - 1;
int v = Integer.parseInt(st.nextToken()) - 1;
g[u][v] = g[v][u] = true;
}
long[][] dp = new long[1 << n][n];
for (int i = 0; i < n; i++) {
dp[1 << i][i] = 1;
}
long[] cycleCounts = new long[n + 1];
for (int mask = 1; mask < (1 << n); mask++) {
int s = Integer.numberOfTrailingZeros(mask);
for (int u = s; u < n; u++) {
if (((mask >> u) & 1) == 1) { // u is in mask
if (dp[mask][u] > 0) {
for (int v = s + 1; v < n; v++) {
if (((mask >> v) & 1) == 0 && g[u][v]) { // v not in mask and edge exists
int nextMask = mask | (1 << v);
dp[nextMask][v] = (dp[nextMask][v] + dp[mask][u]) % MOD;
}
}
int len = Integer.bitCount(mask);
if (len >= 3 && g[u][s]) {
cycleCounts[len] = (cycleCounts[len] + dp[mask][u]) % MOD;
}
}
}
}
}
long[] ans = new long[k];
long inv2 = modInverse(2);
for (int len = 3; len <= n; len++) {
if (cycleCounts[len] > 0) {
long undirectedCount = (cycleCounts[len] * inv2) % MOD;
ans[len % k] = (ans[len % k] + undirectedCount) % MOD;
}
}
PrintWriter out = new PrintWriter(System.out);
for (int i = 0; i < k; i++) {
out.println(ans[i]);
}
out.flush();
}
}
import sys
MOD = 998244353
def power(base, exp):
res = 1
base %= MOD
while exp > 0:
if exp % 2 == 1:
res = (res * base) % MOD
base = (base * base) % MOD
exp //= 2
return res
def mod_inverse(n):
return power(n, MOD - 2)
def main():
try:
input = sys.stdin.readline
n, m, k = map(int, input().split())
g = [[False] * n for _ in range(n)]
for _ in range(m):
u, v = map(int, input().split())
u -= 1
v -= 1
g[u][v] = g[v][u] = True
dp = [[0] * n for _ in range(1 << n)]
for i in range(n):
dp[1 << i][i] = 1
cycle_counts = [0] * (n + 1)
for mask in range(1, 1 << n):
# Find the lowest set bit (the starting node s)
s = (mask & -mask).bit_length() - 1
for u in range(s, n):
if (mask >> u) & 1:
if dp[mask][u] > 0:
# Extend path to a new node v > s
for v in range(s + 1, n):
if not ((mask >> v) & 1) and g[u][v]:
next_mask = mask | (1 << v)
dp[next_mask][v] = (dp[next_mask][v] + dp[mask][u]) % MOD
# Check for a cycle closing back to s
length = bin(mask).count('1')
if length >= 3 and g[u][s]:
cycle_counts[length] = (cycle_counts[length] + dp[mask][u]) % MOD
ans = [0] * k
inv2 = mod_inverse(2)
for length in range(3, n + 1):
if cycle_counts[length] > 0:
undirected_count = (cycle_counts[length] * inv2) % MOD
ans[length % k] = (ans[length % k] + undirected_count) % MOD
for i in range(k):
sys.stdout.write(str(ans[i]) + '\n')
except (IOError, ValueError):
return
main()
算法及复杂度
- 算法:状态压缩动态规划 (DP on Subsets)
- 时间复杂度:
。DP 状态转移是主要耗时部分,有
个
mask
,每个mask
下最多遍历个
u
和个
v
。 - 空间复杂度:
,用于存储 DP 表。