题解:BISHI123 环形字符串跃迁

题目链接

环形字符串跃迁

题目描述

给定长度为 的环形二进制字符串 与参数 。光标从位置 出发进行 次跃迁:若在其后方不含自身的 个字符中存在字符 ,则直接跳到这些 中最远的一个;否则跳到后方一个字符(环形)。共 次询问,每次给出 ,问终点所在下标。

解题思路

将“单步跳转”关系预处理成函数 :表示从 出发一次跃迁到达的位置。随后对该函数做二进制倍增(跳表),即可在 时间回答每次询问。

关键是 的线性预处理:将 复制一遍得到 的长度 串,做“最近的前一个 ”数组 )。对于原串位置 ,窗口为 ,最远的 就是 ,若 ;若没有(),则 。注意所有下标均为 基。

整体复杂度:预处理 ,单次查询 ,其中


代码

#include <bits/stdc++.h>
using namespace std;

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, k, q;
    if(!(cin >> n >> k >> q)) return 0;
    string s; cin >> s; // 0-indexed string of length n
    // build doubled string 1..2n
    vector<int> pre0(2*n + 1, 0);
    for(int i = 1; i <= 2*n; ++i){
        char ch = s[(i-1)%n];
        if(ch == '0') pre0[i] = i; else pre0[i] = pre0[i-1];
    }
    // next transition (1-based)
    vector<int> nxt(n+1);
    for(int i = 1; i <= n; ++i){
        long long R = i + 1LL * k;
        if(R > 2LL*n) R = 2LL*n; // 保守截断(题目通常 k<=n)
        int j = pre0[(int)R];
        if(j > i) nxt[i] = (j-1)%n + 1;
        else nxt[i] = (i % n) + 1;
    }
    // read queries, collect max m
    vector<pair<int, long long>> qs(q);
    long long maxm = 0;
    for(int i=0;i<q;++i){
        int p; long long m; cin >> p >> m; qs[i] = {p, m}; maxm = max(maxm, m);
    }
    int LOG = 0; while((1LL<<LOG) <= maxm) ++LOG;
    vector<vector<int>> up(LOG, vector<int>(n+1));
    for(int i=1;i<=n;++i) up[0][i] = nxt[i];
    for(int b=1;b<LOG;++b){
        for(int i=1;i<=n;++i) up[b][i] = up[b-1][ up[b-1][i] ];
    }
    // answer
    for(auto [p, m] : qs){
        int cur = p;
        for(int b=0; m; ++b){
            if(m & 1) cur = up[b][cur];
            m >>= 1;
        }
        cout << cur << '\n';
    }
    return 0;
}
import java.io.*;

public class Main {
    static class FastScanner {
        private final InputStream in; private final byte[] buf = new byte[1<<16];
        private int p=0,l=0; FastScanner(InputStream is){in=is;}
        private int read() throws IOException { if(p>=l){ l=in.read(buf); p=0; if(l<=0) return -1; } return buf[p++]; }
        int nextInt() throws IOException { int c; int s=1,x=0; do{c=read();}while(c<=32); if(c=='-'){s=-1;c=read();} while(c>32){ x=x*10+(c-'0'); c=read(); } return x*s; }
        long nextLong() throws IOException { int c; long s=1,x=0; do{c=read();}while(c<=32); if(c=='-'){s=-1;c=read();} while(c>32){ x=x*10+(c-'0'); c=read(); } return x*s; }
        String next() throws IOException { int c; StringBuilder sb=new StringBuilder(); do{c=read();}while(c<=32); while(c>32){ sb.append((char)c); c=read(); } return sb.toString(); }
    }

    public static void main(String[] args) throws Exception {
        FastScanner fs = new FastScanner(System.in);
        int n = fs.nextInt();
        int k = fs.nextInt();
        int q = fs.nextInt();
        String s = fs.next();
        int[] pre0 = new int[2*n + 1];
        for(int i=1;i<=2*n;i++){
            char ch = s.charAt((i-1)%n);
            pre0[i] = (ch=='0') ? i : pre0[i-1];
        }
        int[] nxt = new int[n+1];
        for(int i=1;i<=n;i++){
            long R = (long)i + k;
            if(R > 2L*n) R = 2L*n; // 通常 k<=n
            int j = pre0[(int)R];
            if(j > i) nxt[i] = (j-1)%n + 1; else nxt[i] = (i % n) + 1;
        }
        int[] P = new int[q]; long[] M = new long[q]; long maxm = 0;
        for(int i=0;i<q;i++){ P[i]=fs.nextInt(); M[i]=fs.nextLong(); if(M[i]>maxm) maxm=M[i]; }
        int LOG = 0; while((1L<<LOG) <= maxm) ++LOG;
        int[][] up = new int[LOG][n+1];
        for(int i=1;i<=n;i++) up[0][i]=nxt[i];
        for(int b=1;b<LOG;b++) for(int i=1;i<=n;i++) up[b][i]=up[b-1][ up[b-1][i] ];
        StringBuilder out = new StringBuilder();
        for(int i=0;i<q;i++){
            int cur = P[i]; long m = M[i]; int b=0;
            while(m>0){ if((m&1L)==1L) cur = up[b][cur]; m >>= 1; b++; }
            out.append(cur).append('\n');
        }
        System.out.print(out.toString());
    }
}
import sys

data = sys.stdin.buffer.read().split()
it = iter(data)
n = int(next(it)); k = int(next(it)); q = int(next(it))
s = next(it).decode()

# build pre0 on doubled string
pre0 = [0]*(2*n + 1)
for i in range(1, 2*n + 1):
    ch = s[(i-1) % n]
    pre0[i] = i if ch == '0' else pre0[i-1]

nxt = [0]*(n+1)
for i in range(1, n+1):
    R = i + k
    if R > 2*n:
        R = 2*n  # 通常 k<=n
    j = pre0[R]
    nxt[i] = (j-1) % n + 1 if j > i else (i % n) + 1

qs = []
maxm = 0
for _ in range(q):
    p = int(next(it)); m = int(next(it))
    qs.append((p, m))
    if m > maxm: maxm = m

LOG = 0
while (1 << LOG) <= maxm:
    LOG += 1
up = [[0]*(n+1) for _ in range(LOG)]
for i in range(1, n+1):
    up[0][i] = nxt[i]
for b in range(1, LOG):
    row_prev = up[b-1]
    row = up[b]
    for i in range(1, n+1):
        row[i] = row_prev[row_prev[i]]

out_lines = []
for p, m in qs:
    cur = p
    b = 0
    while m:
        if m & 1:
            cur = up[b][cur]
        m >>= 1
        b += 1
    out_lines.append(str(cur))

sys.stdout.write('\n'.join(out_lines))

算法及复杂度

  • 算法:预处理单步跳转 + 二进制倍增回答多次跳转
  • 时间复杂度:
  • 空间复杂度: