题目链接

旺仔哥哥走魔法迷宫

题目描述

在一个 的矩阵中,每个格子都有一个魔力值。从一个起始点开始,每次都等概率地移动到一个魔力值严格小于当前格子的目标格子。每次移动的得分是两点间欧几里得距离的平方。当不存在魔力值更小的格子时,移动停止。求总得分的数学期望。

解题思路

这是一个典型的期望动态规划 (Expectation DP) 问题。由于移动总是从魔力值大的格子走向小的格子,所以移动路径构成一个有向无环图 (DAG),这保证了 DP 的可行性。

DP 状态与顺序

  • 状态定义:设 表示从格子 出发,未来能够获得的总得分的期望值。
  • 计算顺序:要计算格子 的期望值,我们需要知道所有它可能跳往的(即魔力值更小的)格子的期望值。因此,我们应该按照魔力值从小到大的顺序来计算 DP。

DP 转移

设当前格子为 ,其魔力值为 。设所有魔力值小于 的格子的集合为 ,其大小为 。根据期望的线性可加性,从 出发的期望 等于:

这个公式可以拆分为两部分:

前缀和优化

直接计算上述求和式的复杂度很高。我们可以通过展开距离平方公式并使用前缀和来优化。

对所有 求和,得到:

我们会发现,计算 需要的几个值:

  • : 已处理的格子总数
  • : 已处理格子的行坐标之和
  • : 已处理格子的列坐标之和
  • : 已处理格子的坐标平方和
  • : 已处理格子的期望值之和

这些都可以通过维护几个前缀和变量来高效获得。

算法流程

  1. 将所有 个格子存入一个列表,并按魔力值从小到大排序。
  2. 初始化前缀和变量:count = 0, sum_r = 0, sum_c = 0, sum_sq = 0, sum_dp = 0
  3. 遍历排序后的格子列表。分批处理魔力值相同的格子。
  4. 对于一批魔力值相同的格子: a. 使用当前的前缀和变量,为这批中的每一个格子计算出其 dp 值。 b. 计算完毕后,再用这批格子的信息(坐标、坐标平方、计算出的dp值)去更新前缀和变量。
  5. 所有格子计算完毕后,dp[start_r][start_c] 就是最终答案。

代码

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;
using ll = long long;

const int MOD = 998244353;

struct Cell {
    int val, r, c;
};

ll power(ll base, ll exp) {
    ll res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (__int128)res * base % MOD;
        base = (__int128)base * base % MOD;
        exp /= 2;
    }
    return res;
}

ll modInverse(ll n) {
    return power(n, MOD - 2);
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m;
    cin >> n >> m;

    vector<Cell> cells;
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= m; ++j) {
            int val;
            cin >> val;
            cells.push_back({val, i, j});
        }
    }

    int start_r, start_c;
    cin >> start_r >> start_c;

    sort(cells.begin(), cells.end(), [](const Cell& a, const Cell& b) {
        return a.val < b.val;
    });

    vector<vector<ll>> dp(n + 1, vector<ll>(m + 1, 0));
    ll count = 0, sum_r = 0, sum_c = 0, sum_sq = 0, sum_dp = 0;

    int i = 0;
    while (i < n * m) {
        int j = i;
        while (j + 1 < n * m && cells[j + 1].val == cells[i].val) {
            j++;
        }

        // 计算当前批次的dp值
        for (int k = i; k <= j; ++k) {
            ll r = cells[k].r;
            ll c = cells[k].c;
            if (count > 0) {
                ll inv_count = modInverse(count);
                ll r2_c2 = (r * r % MOD + c * c % MOD) % MOD;
                
                ll sum_dist_sq = (count * r2_c2) % MOD;
                sum_dist_sq = (sum_dist_sq - (2 * r % MOD * sum_r % MOD) + MOD) % MOD;
                sum_dist_sq = (sum_dist_sq - (2 * c % MOD * sum_c % MOD) + MOD) % MOD;
                sum_dist_sq = (sum_dist_sq + sum_sq) % MOD;
                
                ll term1 = (sum_dist_sq * inv_count) % MOD;
                ll term2 = (sum_dp * inv_count) % MOD;
                dp[r][c] = (term1 + term2) % MOD;
            }
        }

        // 更新前缀和
        for (int k = i; k <= j; ++k) {
            ll r = cells[k].r;
            ll c = cells[k].c;
            count++;
            sum_r = (sum_r + r) % MOD;
            sum_c = (sum_c + c) % MOD;
            ll r2_c2 = (r * r % MOD + c * c % MOD) % MOD;
            sum_sq = (sum_sq + r2_c2) % MOD;
            sum_dp = (sum_dp + dp[r][c]) % MOD;
        }
        i = j + 1;
    }

    cout << dp[start_r][start_c] << endl;

    return 0;
}
import java.util.*;

public class Main {
    static final int MOD = 998244353;

    static class Cell implements Comparable<Cell> {
        int val, r, c;
        Cell(int val, int r, int c) {
            this.val = val;
            this.r = r;
            this.c = c;
        }
        @Override
        public int compareTo(Cell other) {
            return Integer.compare(this.val, other.val);
        }
    }

    static long power(long base, long exp) {
        long res = 1;
        base %= MOD;
        while (exp > 0) {
            if (exp % 2 == 1) res = (res * base) % MOD;
            base = (base * base) % MOD;
            exp /= 2;
        }
        return res;
    }

    static long modInverse(long n) {
        return power(n, MOD - 2);
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();

        List<Cell> cells = new ArrayList<>();
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= m; j++) {
                cells.add(new Cell(sc.nextInt(), i, j));
            }
        }

        int start_r = sc.nextInt();
        int start_c = sc.nextInt();

        Collections.sort(cells);

        long[][] dp = new long[n + 1][m + 1];
        long count = 0, sum_r = 0, sum_c = 0, sum_sq = 0, sum_dp = 0;

        int i = 0;
        while (i < n * m) {
            int j = i;
            while (j + 1 < n * m && cells.get(j + 1).val == cells.get(i).val) {
                j++;
            }

            for (int k = i; k <= j; k++) {
                long r = cells.get(k).r;
                long c = cells.get(k).c;
                if (count > 0) {
                    long inv_count = modInverse(count);
                    long r2_c2 = (r * r % MOD + c * c % MOD) % MOD;
                    
                    long sum_dist_sq = (count * r2_c2) % MOD;
                    sum_dist_sq = (sum_dist_sq - (2 * r % MOD * sum_r % MOD) + MOD) % MOD;
                    sum_dist_sq = (sum_dist_sq - (2 * c % MOD * sum_c % MOD) + MOD) % MOD;
                    sum_dist_sq = (sum_dist_sq + sum_sq) % MOD;
                    
                    long term1 = (sum_dist_sq * inv_count) % MOD;
                    long term2 = (sum_dp * inv_count) % MOD;
                    dp[(int)r][(int)c] = (term1 + term2) % MOD;
                }
            }

            for (int k = i; k <= j; k++) {
                long r = cells.get(k).r;
                long c = cells.get(k).c;
                count++;
                sum_r = (sum_r + r) % MOD;
                sum_c = (sum_c + c) % MOD;
                long r2_c2 = (r * r % MOD + c * c % MOD) % MOD;
                sum_sq = (sum_sq + r2_c2) % MOD;
                sum_dp = (sum_dp + dp[(int)r][(int)c]) % MOD;
            }
            i = j + 1;
        }

        System.out.println(dp[start_r][start_c]);
    }
}
import sys

MOD = 998244353

def power(base, exp):
    res = 1
    base %= MOD
    while exp > 0:
        if exp % 2 == 1:
            res = (res * base) % MOD
        base = (base * base) % MOD
        exp //= 2
    return res

def mod_inverse(n):
    return power(n, MOD - 2)

def main():
    input = sys.stdin.readline
    n, m = map(int, input().split())
    
    cells = []
    for i in range(1, n + 1):
        row = list(map(int, input().split()))
        for j in range(1, m + 1):
            cells.append((row[j-1], i, j))
    
    start_r, start_c = map(int, input().split())
    
    cells.sort()
    
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    count = 0
    sum_r, sum_c, sum_sq, sum_dp = 0, 0, 0, 0
    
    i = 0
    while i < n * m:
        j = i
        while j + 1 < n * m and cells[j + 1][0] == cells[i][0]:
            j += 1
            
        # Calculate DP for the current batch
        for k in range(i, j + 1):
            val, r, c = cells[k]
            if count > 0:
                inv_count = mod_inverse(count)
                r2_c2 = (r * r + c * c) % MOD
                
                sum_dist_sq = (count * r2_c2) % MOD
                sum_dist_sq = (sum_dist_sq - (2 * r * sum_r) % MOD + MOD) % MOD
                sum_dist_sq = (sum_dist_sq - (2 * c * sum_c) % MOD + MOD) % MOD
                sum_dist_sq = (sum_dist_sq + sum_sq) % MOD
                
                term1 = (sum_dist_sq * inv_count) % MOD
                term2 = (sum_dp * inv_count) % MOD
                dp[r][c] = (term1 + term2) % MOD

        # Update prefix sums with the current batch
        for k in range(i, j + 1):
            val, r, c = cells[k]
            count += 1
            sum_r = (sum_r + r) % MOD
            sum_c = (sum_c + c) % MOD
            r2_c2 = (r * r + c * c) % MOD
            sum_sq = (sum_sq + r2_c2) % MOD
            sum_dp = (sum_dp + dp[r][c]) % MOD
        
        i = j + 1
        
    print(dp[start_r][start_c])

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:期望动态规划 + 前缀和优化
  • 时间复杂度:,瓶颈在于对所有格子进行排序。排序后,计算DP和前缀和的遍历过程是线性的
  • 空间复杂度:,用于存储所有格子的信息和DP表。