题目链接

随机路径长度期望

题目描述

给定一个 个点、 条边的有向无环图(DAG)。等概率随机地从图中所有简单路径的集合中选择一条路径。请求出所选路径的长度(经过的边数)的数学期望,结果对 取模。

一条简单路径可以长度为 ,即起点和终点是同一个点。

解题思路

根据数学期望的定义,路径长度的期望等于所有路径的长度总和除以所有路径的总数

因此,问题被分解为两个核心任务:

  1. 计算图中所有简单路径的总数
  2. 计算图中所有简单路径的长度总和

这是一个可以在DAG上通过动态规划解决的问题。我们可以对每个节点 ,计算从它出发的所有路径的相关信息。由于是DAG,我们可以使用记忆化搜索(即DFS+DP)来高效地计算。

DP状态定义

  • count[u]:从节点 出发的所有简单路径的数量。
  • len[u]:从节点 出发的所有简单路径的长度之和

DP转移方程

我们在图上进行深度优先搜索,对于当前节点 ,其DP值的计算依赖于它的所有后继节点

  1. count[u] 的计算: 从节点 出发的路径可以分为两类:

    • 只包含节点 本身的路径(长度为 )。这种路径只有 条。
    • 经过一条边 u -> v 到达邻居 ,然后接上任意一条从 出发的路径。对于每个邻居 ,可以接上 count[v] 条路径。 因此,转移方程为:
  2. len[u] 的计算: 同样,我们分析从 出发的路径长度之和:

    • 长度为 的路径 {u},其长度贡献为
    • 对于每个邻居 ,考虑所有从 u -> v 开始的路径。这样的路径有 count[v] 条。
      • 对于这 count[v] 条路径,边 u -> v 都贡献了 的长度,因此这部分的长度总和为
      • 这些路径的后半部分(从 开始的部分)的长度总和,正好就是 len[v]。 因此,对于一个邻居 ,它对 len[u] 的总贡献是 len[v] + count[v]。累加所有邻居的贡献,得到转移方程:

算法流程

  1. 构建邻接表来表示图。
  2. 使用记忆化搜索(DFS),计算出所有节点 count[u]len[u]。搜索的基线条件是出度为 的节点,对于这样的节点 count[u] = 1len[u] = 0
  3. 遍历所有节点 ,累加得到全局的总路径数 total_paths = sum(count[i]) 和总长度和 total_len = sum(len[i])
  4. 使用快速幂计算 total_paths 的乘法逆元。
  5. 最终答案为 total_len * inv(total_paths) % MOD

代码

#include <iostream>
#include <vector>
#include <numeric>

using namespace std;
using ll = long long;

const int MOD = 998244353;
int n, m;
vector<vector<int>> adj;
vector<ll> count_dp;
vector<ll> len_dp;

ll power(ll base, ll exp) {
    ll res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (__int128)res * base % MOD;
        base = (__int128)base * base % MOD;
        exp /= 2;
    }
    return res;
}

ll modInverse(ll n) {
    return power(n, MOD - 2);
}

void dfs(int u) {
    if (count_dp[u] != -1) {
        return;
    }

    count_dp[u] = 1;
    len_dp[u] = 0;

    for (int v : adj[u]) {
        dfs(v);
        count_dp[u] = (count_dp[u] + count_dp[v]) % MOD;
        len_dp[u] = (len_dp[u] + len_dp[v] + count_dp[v]) % MOD;
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    cin >> n >> m;
    adj.resize(n + 1);
    count_dp.assign(n + 1, -1);
    len_dp.assign(n + 1, -1);

    for (int i = 0; i < m; ++i) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
    }

    for (int i = 1; i <= n; ++i) {
        if (count_dp[i] == -1) {
            dfs(i);
        }
    }

    ll total_paths = 0;
    ll total_len = 0;
    for (int i = 1; i <= n; ++i) {
        total_paths = (total_paths + count_dp[i]) % MOD;
        total_len = (total_len + len_dp[i]) % MOD;
    }

    ll inv_total_paths = modInverse(total_paths);
    ll ans = (total_len * inv_total_paths) % MOD;
    
    cout << ans << "\n";

    return 0;
}
import java.util.Scanner;
import java.util.ArrayList;
import java.util.List;

public class Main {
    static final int MOD = 998244353;
    static int n, m;
    static List<Integer>[] adj;
    static long[] count_dp;
    static long[] len_dp;

    static long power(long base, long exp) {
        long res = 1;
        base %= MOD;
        while (exp > 0) {
            if (exp % 2 == 1) res = (res * base) % MOD;
            base = (base * base) % MOD;
            exp /= 2;
        }
        return res;
    }

    static long modInverse(long n) {
        return power(n, MOD - 2);
    }

    static void dfs(int u) {
        if (count_dp[u] != -1) {
            return;
        }

        count_dp[u] = 1;
        len_dp[u] = 0;

        for (int v : adj[u]) {
            dfs(v);
            count_dp[u] = (count_dp[u] + count_dp[v]) % MOD;
            len_dp[u] = (len_dp[u] + len_dp[v] + count_dp[v]) % MOD;
        }
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        n = sc.nextInt();
        m = sc.nextInt();
        
        adj = new ArrayList[n + 1];
        for (int i = 1; i <= n; i++) adj[i] = new ArrayList<>();
        
        count_dp = new long[n + 1];
        len_dp = new long[n + 1];
        for(int i = 0; i <= n; i++){
            count_dp[i] = -1;
            len_dp[i] = -1;
        }

        for (int i = 0; i < m; i++) {
            int u = sc.nextInt();
            int v = sc.nextInt();
            adj[u].add(v);
        }

        for (int i = 1; i <= n; i++) {
            if (count_dp[i] == -1) {
                dfs(i);
            }
        }

        long total_paths = 0;
        long total_len = 0;
        for (int i = 1; i <= n; i++) {
            total_paths = (total_paths + count_dp[i]) % MOD;
            total_len = (total_len + len_dp[i]) % MOD;
        }

        long inv_total_paths = modInverse(total_paths);
        long ans = (total_len * inv_total_paths) % MOD;

        System.out.println(ans);
    }
}
import sys

# 增加递归深度限制
sys.setrecursionlimit(200005)

MOD = 998244353
n, m = 0, 0
adj = []
count_dp = []
len_dp = []

def power(base, exp):
    res = 1
    base %= MOD
    while exp > 0:
        if exp % 2 == 1:
            res = (res * base) % MOD
        base = (base * base) % MOD
        exp //= 2
    return res

def mod_inverse(n):
    return power(n, MOD - 2)

def dfs(u):
    if count_dp[u] != -1:
        return

    count_dp[u] = 1
    len_dp[u] = 0

    for v in adj[u]:
        dfs(v)
        count_dp[u] = (count_dp[u] + count_dp[v]) % MOD
        len_dp[u] = (len_dp[u] + len_dp[v] + count_dp[v]) % MOD

def main():
    global n, m, adj, count_dp, len_dp
    n, m = map(int, input().split())
    
    adj = [[] for _ in range(n + 1)]
    count_dp = [-1] * (n + 1)
    len_dp = [-1] * (n + 1)

    for _ in range(m):
        u, v = map(int, input().split())
        adj[u].append(v)
    
    for i in range(1, n + 1):
        if count_dp[i] == -1:
            dfs(i)

    total_paths = 0
    total_len = 0
    for i in range(1, n + 1):
        total_paths = (total_paths + count_dp[i]) % MOD
        total_len = (total_len + len_dp[i]) % MOD

    inv_total_paths = mod_inverse(total_paths)
    ans = (total_len * inv_total_paths) % MOD
    
    print(ans)

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:有向无环图上的动态规划(记忆化搜索)
  • 时间复杂度:,其中 是点数, 是边数。每个节点和每条边在DFS过程中只会被访问一次。
  • 空间复杂度:,用于存储图的邻接表、DP数组以及递归栈。