题目链接
题目描述
在一个 $n \times m$ 的矩阵中,每个格子都有一个魔力值。从一个起始点开始,每次都等概率地移动到一个魔力值严格小于当前格子的目标格子。每次移动的得分是两点间欧几里得距离的平方。当不存在魔力值更小的格子时,移动停止。求总得分的数学期望(对 $998244353$ 取模)。
解题思路
这是一个典型的期望动态规划 (Expectation DP) 问题。由于移动总是从魔力值大的格子走向小的格子,所以移动路径构成一个有向无环图 (DAG),这保证了 DP 的可行性。
DP 状态与顺序
- 状态定义:设 $dp[i][j]$ 表示从格子 $(i, j)$ 出发,未来能够获得的总得分的期望值。
- 计算顺序:要计算格子 $(i, j)$ 的期望值,我们需要知道所有它可能跳往的(即魔力值更小的)格子的期望值。因此,我们应该按照魔力值从小到大的顺序来计算 DP。
DP 转移
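设当前格子为 $(r, c)$,其魔力值为 $v$。设所有魔力值小于 $v$ 的格子的集合为 $S$,其大小为 $k$。若 $k = 0$,则无法移动,$dp[r][c] = 0$。否则,根据期望的线性可加性,从 $(r, c)$ 出发的期望 $dp[r][c]$ 等于:
$$dp[r][c] = \frac{1}{k} \sum_{(r', c') \in S} \left( (r - r')^2 + (c - c')^2 + dp[r'][c'] \right)$$
这个公式可以拆分为两部分:
$$dp[r][c] = \frac{1}{k} \sum_{(r', c') \in S} \left( (r - r')^2 + (c - c')^2 \right) + \frac{1}{k} \sum_{(r', c') \in S} dp[r'][c']$$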
前缀和优化
直接计算上述求和式的复杂度很高。我们可以通过展开距离平方公式并使用前缀和来优化。
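$$(r - r')^2 + (c - c')^2 = (r^2 + c^2) - 2 r r' - 2 c c' + (r'^2 + c'^2)$$
对所有 $(r', c') \in S$ 求和,得到:
$$\sum_{(r', c') \in S} \left( (r - r')^2 + (c - c')^2 \right) = k (r^2 + c^2) - 2 r \sum_{S} r' - 2 c \sum_{S} c' + \sum_{S} (r'^2 + c'^2)$$
我们会发现,计算 $dp[r][c]$ 需要的几个值:
- $k$: 已处理的格子总数
- $\sum_{S} r'$: 已处理格子的行坐标之和
- $\sum_{S} c'$: 已处理格子的列坐标之和
- $\sum_{S} (r'^2 + c'^2)$: 已处理格子的坐标平方和
- $\sum_{S} dp[r'][c']$: 已处理格子的期望值之和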
这些都可以通过维护几个前缀和变量来高效获得。
算法流程
- 将所有 $n \times m$ 个格子存入一个列表,并按魔力值从小到大排序。
- 初始化前缀和变量:`count = 0`,`sum_r = 0`,`sum_c = 0`,`sum_sq = 0`,`sum_dp = 0`。
- 遍历排序后的格子列表。分批处理魔力值相同的格子。
- 对于一批魔力值相同的格子:
  a. 使用当前的前缀和变量,为这批中的每一个格子计算出其 `dp` 值(若此时 `count = 0`,即不存在魔力值更小的格子,则 `dp` 值保持为 0)。
  b. 计算完毕后,再用这批格子的信息(坐标、坐标平方、计算出的 `dp` 值)去更新前缀和变量。先计算、后更新,保证同一批内魔力值相等的格子不会互相转移。
- 所有格子计算完毕后,`dp[start_r][start_c]` 就是最终答案。
代码
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
using ll = long long;
const int MOD = 998244353;
// One matrix cell: its magic value together with its 1-based position.
// Member order (val, r, c) matters because cells are brace-initialized.
struct Cell {
    int val;  // magic value; cells are sorted by this field
    int r;    // row index, starting at 1
    int c;    // column index, starting at 1
};
// Computes base^exp modulo MOD by binary (square-and-multiply) exponentiation.
//   base: reduced modulo MOD before use.
//   exp:  exponent; the loop only runs while exp > 0, so exp <= 0 yields 1.
// Returns the product reduced modulo MOD after every multiplication.
ll power(ll base, ll exp) {
    ll res = 1;
    base %= MOD;
    while (exp > 0) {
        // Both factors have magnitude below MOD < 2^30, so each product stays
        // below 2^60 and fits in a 64-bit long long. The non-standard __int128
        // extension is therefore unnecessary and has been dropped for portability.
        if (exp & 1) res = res * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return res;
}
// Returns n^(MOD-2) modulo MOD, i.e. the modular inverse of n by Fermat's
// little theorem.
// NOTE(review): this is the inverse only if MOD is prime and n is not a
// multiple of MOD — callers here only pass a positive cell count.
ll modInverse(ll n) {
    const ll fermat_exponent = MOD - 2;
    return power(n, fermat_exponent);
}
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
int n, m;
cin >> n >> m;
vector<Cell> cells;
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j) {
int val;
cin >> val;
cells.push_back({val, i, j});
}
}
int start_r, start_c;
cin >> start_r >> start_c;
sort(cells.begin(), cells.end(), [](const Cell& a, const Cell& b) {
return a.val < b.val;
});
vector<vector<ll>> dp(n + 1, vector<ll>(m + 1, 0));
ll count = 0, sum_r = 0, sum_c = 0, sum_sq = 0, sum_dp = 0;
int i = 0;
while (i < n * m) {
int j = i;
while (j + 1 < n * m && cells[j + 1].val == cells[i].val) {
j++;
}
// 计算当前批次的dp值
for (int k = i; k <= j; ++k) {
ll r = cells[k].r;
ll c = cells[k].c;
if (count > 0) {
ll inv_count = modInverse(count);
ll r2_c2 = (r * r % MOD + c * c % MOD) % MOD;
ll sum_dist_sq = (count * r2_c2) % MOD;
sum_dist_sq = (sum_dist_sq - (2 * r % MOD * sum_r % MOD) + MOD) % MOD;
sum_dist_sq = (sum_dist_sq - (2 * c % MOD * sum_c % MOD) + MOD) % MOD;
sum_dist_sq = (sum_dist_sq + sum_sq) % MOD;
ll term1 = (sum_dist_sq * inv_count) % MOD;
ll term2 = (sum_dp * inv_count) % MOD;
dp[r][c] = (term1 + term2) % MOD;
}
}
// 更新前缀和
for (int k = i; k <= j; ++k) {
ll r = cells[k].r;
ll c = cells[k].c;
count++;
sum_r = (sum_r + r) % MOD;
sum_c = (sum_c + c) % MOD;
ll r2_c2 = (r * r % MOD + c * c % MOD) % MOD;
sum_sq = (sum_sq + r2_c2) % MOD;
sum_dp = (sum_dp + dp[r][c]) % MOD;
}
i = j + 1;
}
cout << dp[start_r][start_c] << endl;
return 0;
}
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.StreamTokenizer;
import java.io.UncheckedIOException;
import java.util.*;
/**
 * Expected total score of a random walk on a grid of magic values.
 *
 * <p>From a cell the walk jumps to a uniformly chosen cell with a strictly smaller
 * value and scores the squared Euclidean distance of the jump. The expectation is
 * computed modulo {@link #MOD} by processing cells in increasing value order and
 * keeping running sums over all already processed (strictly smaller) cells.
 */
public class Main {
    static final int MOD = 998244353;

    /** A grid cell: magic value plus 1-based row and column, ordered by magic value only. */
    static class Cell implements Comparable<Cell> {
        final int val;
        final int r;
        final int c;

        Cell(int val, int r, int c) {
            this.val = val;
            this.r = r;
            this.c = c;
        }

        @Override
        public int compareTo(Cell other) {
            return Integer.compare(this.val, other.val);
        }
    }

    /**
     * Computes {@code base^exp mod MOD} by binary exponentiation.
     *
     * @param base the base; reduced modulo {@code MOD} before use
     * @param exp the exponent; any value {@code <= 0} yields 1
     * @return the power reduced modulo {@code MOD} after every multiplication
     */
    static long power(long base, long exp) {
        long res = 1;
        base %= MOD;
        while (exp > 0) {
            // Factors are below MOD < 2^30, so the products fit in a long.
            if ((exp & 1) == 1) {
                res = res * base % MOD;
            }
            base = base * base % MOD;
            exp >>= 1;
        }
        return res;
    }

    /**
     * Returns {@code n^(MOD-2) mod MOD}, the modular inverse of {@code n} by Fermat's
     * little theorem (valid because {@code MOD} is prime and {@code n} is not a multiple of it).
     *
     * @param n the value to invert
     * @return the inverse of {@code n} modulo {@code MOD}
     */
    static long modInverse(long n) {
        return power(n, MOD - 2);
    }

    /**
     * Reads {@code n m}, the {@code n x m} grid and the 1-based start cell from standard
     * input and prints the expected total score modulo {@code MOD}.
     *
     * @param args unused
     * @throws NoSuchElementException if the input ends early or contains a non-numeric token
     * @throws UncheckedIOException if reading standard input fails
     */
    public static void main(String[] args) {
        // StreamTokenizer is far cheaper than Scanner for the up to n*m integers read here.
        StreamTokenizer in =
                new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));
        try {
            int n = nextInt(in);
            int m = nextInt(in);
            List<Cell> cells = new ArrayList<>(n * m);
            for (int row = 1; row <= n; row++) {
                for (int col = 1; col <= m; col++) {
                    cells.add(new Cell(nextInt(in), row, col));
                }
            }
            int startR = nextInt(in);
            int startC = nextInt(in);
            System.out.println(solve(n, m, cells, startR, startC));
        } catch (IOException e) {
            throw new UncheckedIOException("failed to read input", e);
        }
    }

    /** Reads the next token as an int; fails if the input is exhausted or not numeric. */
    private static int nextInt(StreamTokenizer in) throws IOException {
        if (in.nextToken() != StreamTokenizer.TT_NUMBER) {
            throw new NoSuchElementException("expected an integer in the input");
        }
        return (int) in.nval;
    }

    /** Sorts {@code cells} by value, runs the batched DP and returns dp[startR][startC]. */
    private static long solve(int n, int m, List<Cell> cells, int startR, int startC) {
        Collections.sort(cells);
        long[][] dp = new long[n + 1][m + 1];

        // Aggregates over every cell with a strictly smaller value than the current batch.
        // count is exact; the sums are kept modulo MOD.
        long count = 0;
        long sumR = 0;
        long sumC = 0;
        long sumSq = 0;
        long sumDp = 0;

        int total = cells.size();
        int i = 0;
        while (i < total) {
            // [i, j] is the maximal run of cells sharing one magic value.
            int batchVal = cells.get(i).val;
            int j = i;
            while (j + 1 < total && cells.get(j + 1).val == batchVal) {
                j++;
            }

            // With count == 0 no smaller cell exists and dp stays 0.
            if (count > 0) {
                // count is constant within the batch: invert once, not once per cell.
                long invCount = modInverse(count);
                for (int k = i; k <= j; k++) {
                    Cell cell = cells.get(k);
                    long r = cell.r;
                    long c = cell.c;
                    long r2c2 = (r * r % MOD + c * c % MOD) % MOD;
                    // Sum of squared distances:
                    //   count*(r^2+c^2) - 2r*sumR - 2c*sumC + sumSq
                    long sumDistSq = count * r2c2 % MOD;
                    sumDistSq = (sumDistSq - 2 * r % MOD * sumR % MOD + MOD) % MOD;
                    sumDistSq = (sumDistSq - 2 * c % MOD * sumC % MOD + MOD) % MOD;
                    sumDistSq = (sumDistSq + sumSq) % MOD;
                    dp[cell.r][cell.c] = (sumDistSq + sumDp) % MOD * invCount % MOD;
                }
            }

            // Fold the batch in only after all its dp values are known, so equal-valued
            // cells never act as targets for each other.
            for (int k = i; k <= j; k++) {
                Cell cell = cells.get(k);
                long r = cell.r;
                long c = cell.c;
                count++;
                sumR = (sumR + r) % MOD;
                sumC = (sumC + c) % MOD;
                sumSq = (sumSq + (r * r % MOD + c * c % MOD) % MOD) % MOD;
                sumDp = (sumDp + dp[cell.r][cell.c]) % MOD;
            }

            i = j + 1;
        }
        return dp[startR][startC];
    }
}
import sys
MOD = 998244353
def power(base, exp):
    """Return ``base ** exp`` reduced modulo ``MOD``.

    Uses the built-in three-argument ``pow`` (native modular exponentiation)
    instead of a hand-written square-and-multiply loop.

    Args:
        base: Integer base; ``pow`` reduces it modulo ``MOD`` itself.
        exp: Integer exponent. Any value ``<= 0`` yields 1, which is what the
            previous loop-based implementation returned for such exponents.

    Returns:
        The power as an integer in ``[0, MOD)``.
    """
    if exp <= 0:
        return 1
    return pow(base, exp, MOD)
def mod_inverse(n):
    """Return ``n ** (MOD - 2)`` modulo ``MOD`` via ``power``.

    By Fermat's little theorem this is the modular inverse of ``n`` when
    ``MOD`` is prime and ``n`` is not a multiple of ``MOD``.
    """
    fermat_exponent = MOD - 2
    return power(n, fermat_exponent)
def main():
    """Read the grid and start cell from stdin and print the expected score.

    Input is ``n m``, then ``n * m`` magic values in row-major order, then the
    1-based start row and column. Tokens are read whitespace-separated, so the
    exact line layout does not matter. Prints ``dp[start_r][start_c]``, the
    expected total score modulo ``MOD``.
    """
    tokens = iter(sys.stdin.buffer.read().split())
    n = int(next(tokens))
    m = int(next(tokens))

    # (value, row, col) tuples; rows and columns are 1-based.
    cells = []
    for r in range(1, n + 1):
        for c in range(1, m + 1):
            cells.append((int(next(tokens)), r, c))
    start_r = int(next(tokens))
    start_c = int(next(tokens))

    # Smaller values first, so every possible jump target is processed before
    # the cell it is reached from.
    cells.sort()
    dp = [[0] * (m + 1) for _ in range(n + 1)]

    # Running aggregates over all cells strictly smaller than the current
    # batch. count is exact; the sums are kept modulo MOD.
    count = 0
    sum_r, sum_c, sum_sq, sum_dp = 0, 0, 0, 0

    total = len(cells)
    i = 0
    while i < total:
        # [i, j] is the maximal run of cells sharing one magic value.
        j = i
        while j + 1 < total and cells[j + 1][0] == cells[i][0]:
            j += 1

        # With count == 0 there is no smaller cell and dp stays 0.
        if count > 0:
            # count is constant within a batch: invert once per batch rather
            # than once per cell.
            inv_count = mod_inverse(count)
            for k in range(i, j + 1):
                _, r, c = cells[k]
                # Sum of squared distances to all smaller cells:
                #   count*(r^2+c^2) - 2r*sum_r - 2c*sum_c + sum_sq
                sum_dist_sq = (
                    count * (r * r + c * c) - 2 * r * sum_r - 2 * c * sum_c + sum_sq
                ) % MOD
                dp[r][c] = (sum_dist_sq + sum_dp) * inv_count % MOD

        # Fold the batch in only after all its dp values are known, so
        # equal-valued cells never act as targets for each other.
        for k in range(i, j + 1):
            _, r, c = cells[k]
            count += 1
            sum_r = (sum_r + r) % MOD
            sum_c = (sum_c + c) % MOD
            sum_sq = (sum_sq + r * r + c * c) % MOD
            sum_dp = (sum_dp + dp[r][c]) % MOD

        i = j + 1

    print(dp[start_r][start_c])
if __name__ == "__main__":
main()
算法及复杂度
- 算法:期望动态规划 + 前缀和优化
- 时间复杂度:$O(nm \log(nm))$,瓶颈在于对所有格子进行排序。排序后,计算DP和前缀和的遍历过程是线性的 $O(nm)$(其中每次用快速幂求模逆元另需 $O(\log \mathrm{MOD})$)。
- 空间复杂度:$O(nm)$,用于存储所有格子的信息和DP表。