题目链接
题目描述
给定一个 个点、
条边的有向无环图(DAG)。等概率随机地从图中所有简单路径的集合中选择一条路径。请求出所选路径的长度(经过的边数)的数学期望,结果对
取模。
一条简单路径可以长度为 ,即起点和终点是同一个点。
解题思路
根据数学期望的定义,路径长度的期望等于所有路径的长度总和除以所有路径的总数。
因此,问题被分解为两个核心任务:
- 计算图中所有简单路径的总数
。
- 计算图中所有简单路径的长度总和
。
这是一个可以在DAG上通过动态规划解决的问题。我们可以对每个节点 ,计算从它出发的所有路径的相关信息。由于是DAG,我们可以使用记忆化搜索(即DFS+DP)来高效地计算。
DP状态定义
count[u]
:从节点出发的所有简单路径的数量。
len[u]
:从节点出发的所有简单路径的长度之和。
DP转移方程
我们在图上进行深度优先搜索,对于当前节点 ,其DP值的计算依赖于它的所有后继节点
。
-
count[u]
的计算: 从节点出发的路径可以分为两类:
- 只包含节点
本身的路径(长度为
)。这种路径只有
条。
- 经过一条边
u -> v
到达邻居,然后接上任意一条从
出发的路径。对于每个邻居
,可以接上
count[v]
条路径。 因此,转移方程为:
- 只包含节点
-
len[u]
的计算: 同样,我们分析从出发的路径长度之和:
- 长度为
的路径
{u}
,其长度贡献为。
- 对于每个邻居
,考虑所有从
u -> v
开始的路径。这样的路径有count[v]
条。- 对于这
count[v]
条路径,边u -> v
都贡献了的长度,因此这部分的长度总和为
。
- 这些路径的后半部分(从
开始的部分)的长度总和,正好就是
len[v]
。 因此,对于一个邻居,它对
len[u]
的总贡献是len[v] + count[v]
。累加所有邻居的贡献,得到转移方程:
- 对于这
- 长度为
算法流程
- 构建邻接表来表示图。
- 使用记忆化搜索(DFS),计算出所有节点
的
count[u]
和len[u]
。搜索的基线条件是出度为的节点,对于这样的节点
,
count[u] = 1
且len[u] = 0
。 - 遍历所有节点
,累加得到全局的总路径数
total_paths = sum(count[i])
和总长度和total_len = sum(len[i])
。 - 使用快速幂计算
total_paths
对的乘法逆元。
- 最终答案为
total_len * inv(total_paths) % MOD
。
代码
#include <iostream>
#include <vector>
#include <numeric>
using namespace std;
using ll = long long;
const int MOD = 998244353;
int n, m;
vector<vector<int>> adj;
vector<ll> count_dp;
vector<ll> len_dp;
ll power(ll base, ll exp) {
ll res = 1;
base %= MOD;
while (exp > 0) {
if (exp % 2 == 1) res = (__int128)res * base % MOD;
base = (__int128)base * base % MOD;
exp /= 2;
}
return res;
}
ll modInverse(ll n) {
return power(n, MOD - 2);
}
void dfs(int u) {
if (count_dp[u] != -1) {
return;
}
count_dp[u] = 1;
len_dp[u] = 0;
for (int v : adj[u]) {
dfs(v);
count_dp[u] = (count_dp[u] + count_dp[v]) % MOD;
len_dp[u] = (len_dp[u] + len_dp[v] + count_dp[v]) % MOD;
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
cin >> n >> m;
adj.resize(n + 1);
count_dp.assign(n + 1, -1);
len_dp.assign(n + 1, -1);
for (int i = 0; i < m; ++i) {
int u, v;
cin >> u >> v;
adj[u].push_back(v);
}
for (int i = 1; i <= n; ++i) {
if (count_dp[i] == -1) {
dfs(i);
}
}
ll total_paths = 0;
ll total_len = 0;
for (int i = 1; i <= n; ++i) {
total_paths = (total_paths + count_dp[i]) % MOD;
total_len = (total_len + len_dp[i]) % MOD;
}
ll inv_total_paths = modInverse(total_paths);
ll ans = (total_len * inv_total_paths) % MOD;
cout << ans << "\n";
return 0;
}
import java.util.Scanner;
import java.util.ArrayList;
import java.util.List;
public class Main {
static final int MOD = 998244353;
static int n, m;
static List<Integer>[] adj;
static long[] count_dp;
static long[] len_dp;
static long power(long base, long exp) {
long res = 1;
base %= MOD;
while (exp > 0) {
if (exp % 2 == 1) res = (res * base) % MOD;
base = (base * base) % MOD;
exp /= 2;
}
return res;
}
static long modInverse(long n) {
return power(n, MOD - 2);
}
static void dfs(int u) {
if (count_dp[u] != -1) {
return;
}
count_dp[u] = 1;
len_dp[u] = 0;
for (int v : adj[u]) {
dfs(v);
count_dp[u] = (count_dp[u] + count_dp[v]) % MOD;
len_dp[u] = (len_dp[u] + len_dp[v] + count_dp[v]) % MOD;
}
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
n = sc.nextInt();
m = sc.nextInt();
adj = new ArrayList[n + 1];
for (int i = 1; i <= n; i++) adj[i] = new ArrayList<>();
count_dp = new long[n + 1];
len_dp = new long[n + 1];
for(int i = 0; i <= n; i++){
count_dp[i] = -1;
len_dp[i] = -1;
}
for (int i = 0; i < m; i++) {
int u = sc.nextInt();
int v = sc.nextInt();
adj[u].add(v);
}
for (int i = 1; i <= n; i++) {
if (count_dp[i] == -1) {
dfs(i);
}
}
long total_paths = 0;
long total_len = 0;
for (int i = 1; i <= n; i++) {
total_paths = (total_paths + count_dp[i]) % MOD;
total_len = (total_len + len_dp[i]) % MOD;
}
long inv_total_paths = modInverse(total_paths);
long ans = (total_len * inv_total_paths) % MOD;
System.out.println(ans);
}
}
import sys
# 增加递归深度限制
sys.setrecursionlimit(200005)
MOD = 998244353
n, m = 0, 0
adj = []
count_dp = []
len_dp = []
def power(base, exp):
res = 1
base %= MOD
while exp > 0:
if exp % 2 == 1:
res = (res * base) % MOD
base = (base * base) % MOD
exp //= 2
return res
def mod_inverse(n):
return power(n, MOD - 2)
def dfs(u):
if count_dp[u] != -1:
return
count_dp[u] = 1
len_dp[u] = 0
for v in adj[u]:
dfs(v)
count_dp[u] = (count_dp[u] + count_dp[v]) % MOD
len_dp[u] = (len_dp[u] + len_dp[v] + count_dp[v]) % MOD
def main():
global n, m, adj, count_dp, len_dp
n, m = map(int, input().split())
adj = [[] for _ in range(n + 1)]
count_dp = [-1] * (n + 1)
len_dp = [-1] * (n + 1)
for _ in range(m):
u, v = map(int, input().split())
adj[u].append(v)
for i in range(1, n + 1):
if count_dp[i] == -1:
dfs(i)
total_paths = 0
total_len = 0
for i in range(1, n + 1):
total_paths = (total_paths + count_dp[i]) % MOD
total_len = (total_len + len_dp[i]) % MOD
inv_total_paths = mod_inverse(total_paths)
ans = (total_len * inv_total_paths) % MOD
print(ans)
if __name__ == "__main__":
main()
算法及复杂度
- 算法:有向无环图上的动态规划(记忆化搜索)
- 时间复杂度:
,其中
是点数,
是边数。每个节点和每条边在DFS过程中只会被访问一次。
- 空间复杂度:
,用于存储图的邻接表、DP数组以及递归栈。