这是一道倍增题目,因为k最大可以取到1e18,所以需要优化转化为二进制来计算
定义st[i][j]表示从i这个位置经行
次跃迁后到达的位置,那么st[i][j] = st[st[i][j - 1]][j - 1], 表示的是从st[i][j - 1](即先从i这个位置进行
次跃迁)然后再进行
次跃迁.那么就相当于经行了
次跃迁
观察st[i][j]的转移方程式,其实只需要求出所有的st[i][0],那么就可以推出所有的st[i][j]
st[i][0]表示从i位置进行
即1次跃迁的结果,求st[i][0]的具体方法是:
首先需要将字符串s扩大一倍,因为这是环形:
string s;
cin >> s;
vector<char> a(2 * n + 10);
for(int i = 1; i <= n; i ++){
a[i] = s[i - 1];
a[i + n] = s[i - 1];
}
然后将区间[1, 2n]里面所有0的位置存下来:
vector<int> zero;
for(int i = 1; i <= 2 * n; i ++){
if(a[i] == '0') zero.push_back(i);
}
现在就是求每一个st[i][0]:
从i跃迁一次,考虑的位置是[i + 1, i + m],首先需要求出这个区间里面最远的0的位置,这也是为什么要用vector将[1, 2n]中所有0位置存下来的原因:为了求跃迁范围里面最远的0位置
auto it = upper_bound(zero.begin(), zero.end(), r);
if(it != zero.begin()){
it --;
int pos = *it;
if(pos >= l){
if(pos > n) st[i][0] = pos - n;
else st[i][0] = pos;
}
else st[i][0] = (i % n) + 1;
}
else st[i][0] = (i % n) + 1;
it 是 求出zero中第一个比 i + m大的位置,那么:
如果it = zero.begin():表示zero中所有0的位置都比i + m大,意思就是[i + 1, i + m]中没有0,所以st[i][0]就等于下一个位置
如果it != zero.begin(),那么it --, 此时pos就是最后一个 小于等于 i + m 的0的位置,再如果pos
i + 1,那么在合法范围里面,此时还需要判断是否大于n ....;如果pos < i + 1, 那么也表示[i + 1, i + m]中没有0,所以st[i][0]就等于下一个位置。
for(int j = 1; j < M; j ++){
for(int i = 1; i <= n; i ++){
st[i][j] = st[st[i - 1][j - 1]][j - 1];
}
}
然后求每一个询问:
while(q --){
int t, k; cin >> t >> k;
int ans = t;
for(int i = M - 1; i >= 0; i --){
if(k >= (1LL << i)){
k -= (1LL << i);
ans = st[ans][i];
}
}
cout << ans << endl;
}
总代码:
#include<bits/stdc++.h>
using namespace std;
#define endl '\n'
#define int long long
#define IOS ios::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
#define HelloWorld IOS;
const int N = 5e5 + 10;
const int M = 63;
int n, m, q;
int st[N][M];
string s;
signed main(){
HelloWorld;
cin >> n >> m >> q;
cin >> s;
vector<char> a(2 * n + 10);
for(int i = 1; i <= n; i ++){
a[i] = s[i - 1];
a[i + n] = s[i - 1];
}
vector<int> zero;
for(int i = 1; i <= 2 * n; i ++){
if(a[i] == '0') zero.push_back(i);
}
for(int i = 1; i <= n; i ++){
int l = i + 1, r = i + m;
auto it = upper_bound(zero.begin(), zero.end(), r);
if(it != zero.begin()){
it --;
int pos = *it;
if(pos >= l){
if(pos > n) st[i][0] = pos - n;
else st[i][0] = pos;
}
else st[i][0] = (i % n) + 1;
}
else st[i][0] = (i % n) + 1;
}
for(int j = 1; j < M; j ++){
for(int i = 1; i <= n; i ++){
st[i][j] = st[st[i][j - 1]][j - 1];
}
}
while(q --){
int t, k; cin >> t >> k;
int ans = t;
for(int i = M - 1; i >= 0; i --){
if(k >= (1LL << i)){
k -= (1LL << i);
ans = st[ans][i];
}
}
cout << ans << endl;
}
return 0;
}



京公网安备 11010502036488号