题目链接

棋盘

题目描述

在一个 的棋盘上,每个格子写有数字 0 或 1。当小球在某个格子上时,如果数字是 0,它会向下移动一格;如果是 1,它会向右移动一格。对于给定的 个查询,每个查询指定一个子矩阵范围和一个起始点,要求计算小球从该起始点出发,在子矩阵内滚动,最终会从哪个格子滚出子矩阵。

解题思路

对于每个查询,一个朴素的方法是根据规则一步步模拟小球的移动,直到它越出子矩阵边界。对于一个 的棋盘,单次模拟最坏情况下可能需要 步。考虑到查询数量 可能很大(高达 ), 的总复杂度可能会超时。

为了优化这个过程,我们可以注意到小球的移动路径是完全确定的。从任何一个格子 出发,它的下一个位置是唯一的。这种性质非常适合使用**倍增(Binary Lifting)**算法来加速。

倍增算法的核心思想是预处理出从每个点出发,走 步后会到达的位置,从而将模拟过程从“一步一步走”加速为“大步跳跃”。

1. 预处理

我们创建一个三维数组 jump[k][r][c],用于存储从格子 出发,经过 次移动后到达的位置。

  • 基础状态 (k=0): jump[0][r][c] 表示从 移动 1 步到达的位置。这可以根据棋盘上 grid[r][c] 的值直接确定:

    • 如果 grid[r][c] == 0,则 jump[0][r][c] = (r+1, c)
    • 如果 grid[r][c] == 1,则 jump[0][r][c] = (r, c+1)
  • 递推关系 (k>0): 移动 步等价于先移动 步,到达一个中间位置,然后再从那个中间位置移动 步。因此,递推公式为: jump[k][r][c] = jump[k-1][ jump[k-1][r][c] ]

整个预处理过程的时间复杂度为

2. 查询

对于每个查询,给定子矩阵 和起点 ,我们的目标是找到路径上最后一个在子矩阵内的格子。

我们可以从起点开始,尝试进行大步的跳跃。为了确保不错过最终答案,我们应该从最大的步长(最大的 )开始尝试递减地跳跃。

  • 设当前位置为 pos = (r, c),初始时 pos = (r1, c1)
  • k = max_logN 向下循环到 0:
    • 计算尝试跳跃 步后的目标位置 next_pos = jump[k][pos.r][pos.c]
    • 如果 next_pos 仍然在子矩阵的边界内(即 next_pos.r <= r2next_pos.c <= c2),说明这次跳跃是安全的,我们可以执行它。更新 pos = next_pos
    • 如果 next_pos 会跳出子矩阵,则放弃这次跳跃,尝试更小的步长(即更小的 )。

当循环结束后,pos 就是小球滚出子矩阵前的最后一个位置。每次查询的时间复杂度为

代码

#include <iostream>
#include <vector>
#include <cmath>

using namespace std;

const int MAXN = 1001;
const int LOGN = 11; // ceil(log2(1001))

pair<int, int> jump[LOGN][MAXN][MAXN];
int grid[MAXN][MAXN];

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;

    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= n; ++j) {
            cin >> grid[i][j];
        }
    }

    // Precomputation - Step 1 (k=0)
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= n; ++j) {
            if (grid[i][j] == 0) {
                jump[0][i][j] = {i + 1, j};
            } else {
                jump[0][i][j] = {i, j + 1};
            }
        }
    }
     for (int i = 1; i <= n + 1; ++i) {
        jump[0][n + 1][i] = {n + 1, i};
        jump[0][i][n + 1] = {i, n + 1};
    }


    // Precomputation - Step 2 (k>0)
    for (int k = 1; k < LOGN; ++k) {
        for (int i = 1; i <= n + 1; ++i) {
            for (int j = 1; j <= n + 1; ++j) {
                pair<int, int> mid = jump[k - 1][i][j];
                jump[k][i][j] = jump[k - 1][mid.first][mid.second];
            }
        }
    }

    int q;
    cin >> q;
    while (q--) {
        int r1, c1, r2, c2;
        cin >> r1 >> c1 >> r2 >> c2;

        int cur_r = r1, cur_c = c1;

        for (int k = LOGN - 1; k >= 0; --k) {
            pair<int, int> next_pos = jump[k][cur_r][cur_c];
            if (next_pos.first <= r2 && next_pos.second <= c2) {
                cur_r = next_pos.first;
                cur_c = next_pos.second;
            }
        }
        cout << cur_r << " " << cur_c << "\n";
    }

    return 0;
}
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.IOException;
import java.util.StringTokenizer;

public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(br.readLine());

        int[][] grid = new int[n + 1][n + 1];
        for (int i = 1; i <= n; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            for (int j = 1; j <= n; j++) {
                grid[i][j] = Integer.parseInt(st.nextToken());
            }
        }

        int LOGN = (int) Math.ceil(Math.log(n * 2) / Math.log(2)) + 1;
        int[][][] jumpR = new int[LOGN][n + 2][n + 2];
        int[][][] jumpC = new int[LOGN][n + 2][n + 2];

        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= n; j++) {
                if (grid[i][j] == 0) {
                    jumpR[0][i][j] = i + 1;
                    jumpC[0][i][j] = j;
                } else {
                    jumpR[0][i][j] = i;
                    jumpC[0][i][j] = j + 1;
                }
            }
        }
        for (int i = 1; i <= n + 1; i++) {
            jumpR[0][n + 1][i] = n + 1;
            jumpC[0][n + 1][i] = i;
            jumpR[0][i][n + 1] = i;
            jumpC[0][i][n + 1] = n + 1;
        }

        for (int k = 1; k < LOGN; k++) {
            for (int i = 1; i <= n + 1; i++) {
                for (int j = 1; j <= n + 1; j++) {
                    int midR = jumpR[k - 1][i][j];
                    int midC = jumpC[k - 1][i][j];
                    jumpR[k][i][j] = jumpR[k - 1][midR][midC];
                    jumpC[k][i][j] = jumpC[k - 1][midR][midC];
                }
            }
        }

        int q = Integer.parseInt(br.readLine());
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < q; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            int r1 = Integer.parseInt(st.nextToken());
            int c1 = Integer.parseInt(st.nextToken());
            int r2 = Integer.parseInt(st.nextToken());
            int c2 = Integer.parseInt(st.nextToken());

            int curR = r1;
            int curC = c1;

            for (int k = LOGN - 1; k >= 0; k--) {
                int nextR = jumpR[k][curR][curC];
                int nextC = jumpC[k][curR][curC];
                if (nextR <= r2 && nextC <= c2) {
                    curR = nextR;
                    curC = nextC;
                }
            }
            sb.append(curR).append(" ").append(curC).append("\n");
        }
        System.out.print(sb.toString());
    }
}
import sys

def main():
    n_str = sys.stdin.readline()
    if not n_str: return
    n = int(n_str)
    
    grid = []
    for _ in range(n):
        grid.append(list(map(int, sys.stdin.readline().split())))

    LOGN = (n * 2).bit_length()

    jump = [[[(0, 0) for _ in range(n + 2)] for _ in range(n + 2)] for _ in range(LOGN)]

    for r in range(1, n + 1):
        for c in range(1, n + 1):
            if grid[r - 1][c - 1] == 0:
                jump[0][r][c] = (r + 1, c)
            else:
                jump[0][r][c] = (r, c + 1)
    
    for i in range(1, n + 2):
        jump[0][n + 1][i] = (n + 1, i)
        jump[0][i][n + 1] = (i, n + 1)


    for k in range(1, LOGN):
        for r in range(1, n + 2):
            for c in range(1, n + 2):
                mid_r, mid_c = jump[k - 1][r][c]
                jump[k][r][c] = jump[k - 1][mid_r][mid_c]

    q_str = sys.stdin.readline()
    if not q_str: return
    q = int(q_str)
    
    for _ in range(q):
        r1, c1, r2, c2 = map(int, sys.stdin.readline().split())
        
        cur_r, cur_c = r1, c1
        
        for k in range(LOGN - 1, -1, -1):
            next_r, next_c = jump[k][cur_r][cur_c]
            if next_r <= r2 and next_c <= c2:
                cur_r, cur_c = next_r, next_c
                
        sys.stdout.write(f"{cur_r} {cur_c}\n")

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法: 倍增 (Binary Lifting)
  • 时间复杂度:
    • 预处理 jump 表的时间为
    • 每次查询使用倍增法的时间为 ,共 次查询。
  • 空间复杂度: ,用于存储 jump 表。