题目链接
题目描述
在一个 的棋盘上,每个格子写有数字 0 或 1。当小球在某个格子上时,如果数字是 0,它会向下移动一格;如果是 1,它会向右移动一格。对于给定的
个查询,每个查询指定一个子矩阵范围和一个起始点,要求计算小球从该起始点出发,在子矩阵内滚动,最终会从哪个格子滚出子矩阵。
解题思路
对于每个查询,一个朴素的方法是根据规则一步步模拟小球的移动,直到它越出子矩阵边界。对于一个 的棋盘,单次模拟最坏情况下可能需要
步。考虑到查询数量
可能很大(高达
),
的总复杂度可能会超时。
为了优化这个过程,我们可以注意到小球的移动路径是完全确定的。从任何一个格子 出发,它的下一个位置是唯一的。这种性质非常适合使用**倍增(Binary Lifting)**算法来加速。
倍增算法的核心思想是预处理出从每个点出发,走 步后会到达的位置,从而将模拟过程从“一步一步走”加速为“大步跳跃”。
1. 预处理
我们创建一个三维数组 jump[k][r][c]
,用于存储从格子 出发,经过
次移动后到达的位置。
-
基础状态 (k=0):
jump[0][r][c]
表示从移动 1 步到达的位置。这可以根据棋盘上
grid[r][c]
的值直接确定:- 如果
grid[r][c] == 0
,则jump[0][r][c] = (r+1, c)
。 - 如果
grid[r][c] == 1
,则jump[0][r][c] = (r, c+1)
。
- 如果
-
递推关系 (k>0): 移动
步等价于先移动
步,到达一个中间位置,然后再从那个中间位置移动
步。因此,递推公式为:
jump[k][r][c] = jump[k-1][ jump[k-1][r][c] ]
整个预处理过程的时间复杂度为 。
2. 查询
对于每个查询,给定子矩阵 和起点
,我们的目标是找到路径上最后一个在子矩阵内的格子。
我们可以从起点开始,尝试进行大步的跳跃。为了确保不错过最终答案,我们应该从最大的步长(最大的 )开始尝试递减地跳跃。
- 设当前位置为
pos = (r, c)
,初始时pos = (r1, c1)
。 - 从
k = max_logN
向下循环到 0:- 计算尝试跳跃
步后的目标位置
next_pos = jump[k][pos.r][pos.c]
。 - 如果
next_pos
仍然在子矩阵的边界内(即next_pos.r <= r2
且next_pos.c <= c2
),说明这次跳跃是安全的,我们可以执行它。更新pos = next_pos
。 - 如果
next_pos
会跳出子矩阵,则放弃这次跳跃,尝试更小的步长(即更小的)。
- 计算尝试跳跃
当循环结束后,pos
就是小球滚出子矩阵前的最后一个位置。每次查询的时间复杂度为 。
代码
#include <iostream>
#include <vector>
#include <cmath>
using namespace std;
const int MAXN = 1001;
const int LOGN = 11; // ceil(log2(1001))
pair<int, int> jump[LOGN][MAXN][MAXN];
int grid[MAXN][MAXN];
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n;
cin >> n;
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= n; ++j) {
cin >> grid[i][j];
}
}
// Precomputation - Step 1 (k=0)
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= n; ++j) {
if (grid[i][j] == 0) {
jump[0][i][j] = {i + 1, j};
} else {
jump[0][i][j] = {i, j + 1};
}
}
}
for (int i = 1; i <= n + 1; ++i) {
jump[0][n + 1][i] = {n + 1, i};
jump[0][i][n + 1] = {i, n + 1};
}
// Precomputation - Step 2 (k>0)
for (int k = 1; k < LOGN; ++k) {
for (int i = 1; i <= n + 1; ++i) {
for (int j = 1; j <= n + 1; ++j) {
pair<int, int> mid = jump[k - 1][i][j];
jump[k][i][j] = jump[k - 1][mid.first][mid.second];
}
}
}
int q;
cin >> q;
while (q--) {
int r1, c1, r2, c2;
cin >> r1 >> c1 >> r2 >> c2;
int cur_r = r1, cur_c = c1;
for (int k = LOGN - 1; k >= 0; --k) {
pair<int, int> next_pos = jump[k][cur_r][cur_c];
if (next_pos.first <= r2 && next_pos.second <= c2) {
cur_r = next_pos.first;
cur_c = next_pos.second;
}
}
cout << cur_r << " " << cur_c << "\n";
}
return 0;
}
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.IOException;
import java.util.StringTokenizer;
public class Main {
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
int n = Integer.parseInt(br.readLine());
int[][] grid = new int[n + 1][n + 1];
for (int i = 1; i <= n; i++) {
StringTokenizer st = new StringTokenizer(br.readLine());
for (int j = 1; j <= n; j++) {
grid[i][j] = Integer.parseInt(st.nextToken());
}
}
int LOGN = (int) Math.ceil(Math.log(n * 2) / Math.log(2)) + 1;
int[][][] jumpR = new int[LOGN][n + 2][n + 2];
int[][][] jumpC = new int[LOGN][n + 2][n + 2];
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
if (grid[i][j] == 0) {
jumpR[0][i][j] = i + 1;
jumpC[0][i][j] = j;
} else {
jumpR[0][i][j] = i;
jumpC[0][i][j] = j + 1;
}
}
}
for (int i = 1; i <= n + 1; i++) {
jumpR[0][n + 1][i] = n + 1;
jumpC[0][n + 1][i] = i;
jumpR[0][i][n + 1] = i;
jumpC[0][i][n + 1] = n + 1;
}
for (int k = 1; k < LOGN; k++) {
for (int i = 1; i <= n + 1; i++) {
for (int j = 1; j <= n + 1; j++) {
int midR = jumpR[k - 1][i][j];
int midC = jumpC[k - 1][i][j];
jumpR[k][i][j] = jumpR[k - 1][midR][midC];
jumpC[k][i][j] = jumpC[k - 1][midR][midC];
}
}
}
int q = Integer.parseInt(br.readLine());
StringBuilder sb = new StringBuilder();
for (int i = 0; i < q; i++) {
StringTokenizer st = new StringTokenizer(br.readLine());
int r1 = Integer.parseInt(st.nextToken());
int c1 = Integer.parseInt(st.nextToken());
int r2 = Integer.parseInt(st.nextToken());
int c2 = Integer.parseInt(st.nextToken());
int curR = r1;
int curC = c1;
for (int k = LOGN - 1; k >= 0; k--) {
int nextR = jumpR[k][curR][curC];
int nextC = jumpC[k][curR][curC];
if (nextR <= r2 && nextC <= c2) {
curR = nextR;
curC = nextC;
}
}
sb.append(curR).append(" ").append(curC).append("\n");
}
System.out.print(sb.toString());
}
}
import sys
def main():
n_str = sys.stdin.readline()
if not n_str: return
n = int(n_str)
grid = []
for _ in range(n):
grid.append(list(map(int, sys.stdin.readline().split())))
LOGN = (n * 2).bit_length()
jump = [[[(0, 0) for _ in range(n + 2)] for _ in range(n + 2)] for _ in range(LOGN)]
for r in range(1, n + 1):
for c in range(1, n + 1):
if grid[r - 1][c - 1] == 0:
jump[0][r][c] = (r + 1, c)
else:
jump[0][r][c] = (r, c + 1)
for i in range(1, n + 2):
jump[0][n + 1][i] = (n + 1, i)
jump[0][i][n + 1] = (i, n + 1)
for k in range(1, LOGN):
for r in range(1, n + 2):
for c in range(1, n + 2):
mid_r, mid_c = jump[k - 1][r][c]
jump[k][r][c] = jump[k - 1][mid_r][mid_c]
q_str = sys.stdin.readline()
if not q_str: return
q = int(q_str)
for _ in range(q):
r1, c1, r2, c2 = map(int, sys.stdin.readline().split())
cur_r, cur_c = r1, c1
for k in range(LOGN - 1, -1, -1):
next_r, next_c = jump[k][cur_r][cur_c]
if next_r <= r2 and next_c <= c2:
cur_r, cur_c = next_r, next_c
sys.stdout.write(f"{cur_r} {cur_c}\n")
if __name__ == "__main__":
main()
算法及复杂度
- 算法: 倍增 (Binary Lifting)
- 时间复杂度:
- 预处理
jump
表的时间为。
- 每次查询使用倍增法的时间为
,共
次查询。
- 预处理
- 空间复杂度:
,用于存储
jump
表。