题解:BISHI123 环形字符串跃迁
题目链接
题目描述
给定长度为 的环形二进制字符串
与参数
。光标从位置
出发进行
次跃迁:若在其后方不含自身的
个字符中存在字符
,则直接跳到这些
中最远的一个;否则跳到后方一个字符(环形)。共
次询问,每次给出
,问终点所在下标。
解题思路
将“单步跳转”关系预处理成函数 :表示从
出发一次跃迁到达的位置。随后对该函数做二进制倍增(跳表),即可在
时间回答每次询问。
关键是 的线性预处理:将
复制一遍得到
的长度
串,做“最近的前一个
”数组
(
)。对于原串位置
,窗口为
,最远的
就是
,若
则
;若没有(
),则
。注意所有下标均为
基。
整体复杂度:预处理 ,单次查询
,其中
。
代码
#include <bits/stdc++.h>
using namespace std;
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, k, q;
if(!(cin >> n >> k >> q)) return 0;
string s; cin >> s; // 0-indexed string of length n
// build doubled string 1..2n
vector<int> pre0(2*n + 1, 0);
for(int i = 1; i <= 2*n; ++i){
char ch = s[(i-1)%n];
if(ch == '0') pre0[i] = i; else pre0[i] = pre0[i-1];
}
// next transition (1-based)
vector<int> nxt(n+1);
for(int i = 1; i <= n; ++i){
long long R = i + 1LL * k;
if(R > 2LL*n) R = 2LL*n; // 保守截断(题目通常 k<=n)
int j = pre0[(int)R];
if(j > i) nxt[i] = (j-1)%n + 1;
else nxt[i] = (i % n) + 1;
}
// read queries, collect max m
vector<pair<int, long long>> qs(q);
long long maxm = 0;
for(int i=0;i<q;++i){
int p; long long m; cin >> p >> m; qs[i] = {p, m}; maxm = max(maxm, m);
}
int LOG = 0; while((1LL<<LOG) <= maxm) ++LOG;
vector<vector<int>> up(LOG, vector<int>(n+1));
for(int i=1;i<=n;++i) up[0][i] = nxt[i];
for(int b=1;b<LOG;++b){
for(int i=1;i<=n;++i) up[b][i] = up[b-1][ up[b-1][i] ];
}
// answer
for(auto [p, m] : qs){
int cur = p;
for(int b=0; m; ++b){
if(m & 1) cur = up[b][cur];
m >>= 1;
}
cout << cur << '\n';
}
return 0;
}
import java.io.*;
public class Main {
static class FastScanner {
private final InputStream in; private final byte[] buf = new byte[1<<16];
private int p=0,l=0; FastScanner(InputStream is){in=is;}
private int read() throws IOException { if(p>=l){ l=in.read(buf); p=0; if(l<=0) return -1; } return buf[p++]; }
int nextInt() throws IOException { int c; int s=1,x=0; do{c=read();}while(c<=32); if(c=='-'){s=-1;c=read();} while(c>32){ x=x*10+(c-'0'); c=read(); } return x*s; }
long nextLong() throws IOException { int c; long s=1,x=0; do{c=read();}while(c<=32); if(c=='-'){s=-1;c=read();} while(c>32){ x=x*10+(c-'0'); c=read(); } return x*s; }
String next() throws IOException { int c; StringBuilder sb=new StringBuilder(); do{c=read();}while(c<=32); while(c>32){ sb.append((char)c); c=read(); } return sb.toString(); }
}
public static void main(String[] args) throws Exception {
FastScanner fs = new FastScanner(System.in);
int n = fs.nextInt();
int k = fs.nextInt();
int q = fs.nextInt();
String s = fs.next();
int[] pre0 = new int[2*n + 1];
for(int i=1;i<=2*n;i++){
char ch = s.charAt((i-1)%n);
pre0[i] = (ch=='0') ? i : pre0[i-1];
}
int[] nxt = new int[n+1];
for(int i=1;i<=n;i++){
long R = (long)i + k;
if(R > 2L*n) R = 2L*n; // 通常 k<=n
int j = pre0[(int)R];
if(j > i) nxt[i] = (j-1)%n + 1; else nxt[i] = (i % n) + 1;
}
int[] P = new int[q]; long[] M = new long[q]; long maxm = 0;
for(int i=0;i<q;i++){ P[i]=fs.nextInt(); M[i]=fs.nextLong(); if(M[i]>maxm) maxm=M[i]; }
int LOG = 0; while((1L<<LOG) <= maxm) ++LOG;
int[][] up = new int[LOG][n+1];
for(int i=1;i<=n;i++) up[0][i]=nxt[i];
for(int b=1;b<LOG;b++) for(int i=1;i<=n;i++) up[b][i]=up[b-1][ up[b-1][i] ];
StringBuilder out = new StringBuilder();
for(int i=0;i<q;i++){
int cur = P[i]; long m = M[i]; int b=0;
while(m>0){ if((m&1L)==1L) cur = up[b][cur]; m >>= 1; b++; }
out.append(cur).append('\n');
}
System.out.print(out.toString());
}
}
import sys
data = sys.stdin.buffer.read().split()
it = iter(data)
n = int(next(it)); k = int(next(it)); q = int(next(it))
s = next(it).decode()
# build pre0 on doubled string
pre0 = [0]*(2*n + 1)
for i in range(1, 2*n + 1):
ch = s[(i-1) % n]
pre0[i] = i if ch == '0' else pre0[i-1]
nxt = [0]*(n+1)
for i in range(1, n+1):
R = i + k
if R > 2*n:
R = 2*n # 通常 k<=n
j = pre0[R]
nxt[i] = (j-1) % n + 1 if j > i else (i % n) + 1
qs = []
maxm = 0
for _ in range(q):
p = int(next(it)); m = int(next(it))
qs.append((p, m))
if m > maxm: maxm = m
LOG = 0
while (1 << LOG) <= maxm:
LOG += 1
up = [[0]*(n+1) for _ in range(LOG)]
for i in range(1, n+1):
up[0][i] = nxt[i]
for b in range(1, LOG):
row_prev = up[b-1]
row = up[b]
for i in range(1, n+1):
row[i] = row_prev[row_prev[i]]
out_lines = []
for p, m in qs:
cur = p
b = 0
while m:
if m & 1:
cur = up[b][cur]
m >>= 1
b += 1
out_lines.append(str(cur))
sys.stdout.write('\n'.join(out_lines))
算法及复杂度
- 算法:预处理单步跳转 + 二进制倍增回答多次跳转
- 时间复杂度:
,
- 空间复杂度: