题意
给你一个长度为n的字符串与q个查询
对于每一次查询 $[l, r]$,将区间 $[l, r]$ 翻转,并求出翻转后整个字符串中有多少个 red 子序列(每一次操作相互独立,互不影响)
思路
使用倍增表分别维护从下标 $i$ 开始、长度为 $2^j$ 的区间中 r, e, d, re, ed, red 与 de, er, der 子序列的个数(其实两张表可以合并成同一个数组,这里分开写成了两个)
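这样把若干个长度为 $2^j$ 的区间从左到右依次拼接,就可以在 $O(\log n)$ 时间内得到任意区间中上述各子序列的个数。翻转后的区间里的 red/re/ed 恰好对应原区间中的 der/er/de,因此只需按 r、e、d 分别落在前缀、翻转区间、后缀三段中的情况分类,把各部分的计数相乘后相加即可。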
代码
#include <bits/stdc++.h>
using namespace std;
// Use a guaranteed 64-bit type. "red" subsequence counts can reach about
// (n/3)^3 ~ 3.7e13 for strings near the N limit, which overflows a 32-bit
// `long` (e.g. on Windows / LLP64 toolchains, where long is 32 bits).
using ll = long long;
using ull = unsigned long long;
const ll N = 1e5 + 5; // tables are indexed 1..n, so n must stay below N
const ll mod = 1e9 + 7;
using db = double;
const double eps = 1e-6;
#define endl '\n'
#define PII pair<ll, ll>
#define PIII array<ll, 3>
#define fi first
#define se second
// Slot indices into dp[i][bit][*]: single letters, then the pairs, then the triple.
#define R 1
#define E 2
#define D 3
#define RE 4
#define ED 5
#define RED 6
// ndp is built for the reversed letter order d-e-r and reuses the same slots:
// DE shares slot 4 with RE, ER slot 5 with ED, DER slot 6 with RED.
#define DER 6
#define ER 5
#define DE 4
// Template note: hash value *131 + 23317 (not referenced in this file).
ll arr[N];
ll n, m, k;
ll ans;
ll q;
string s;
// dp[i][bit][slot]: counts of r/e/d/re/ed/red subsequences in s[i .. i + 2^bit - 1].
ll dp[N][18][7];
// ndp[i][bit][slot]: counts of r/e/d/de/er/der subsequences in the same window.
ll ndp[N][18][7];
void init()
{
for (ll bit = 0; bit < 18; bit++)
{
for (ll i = 1; i <= n; i++)
{
if (i + (1ll << bit) - 1 > n)
break;
if (bit == 0)
{
if (s[i] == 'r')
dp[i][bit][R]++;
if (s[i] == 'e')
dp[i][bit][E]++;
if (s[i] == 'd')
dp[i][bit][D]++;
}
else
{
dp[i][bit][RED] += dp[i][bit - 1][RED];
dp[i][bit][RE] += dp[i][bit - 1][RE];
dp[i][bit][ED] += dp[i][bit - 1][ED];
dp[i][bit][R] += dp[i][bit - 1][R];
dp[i][bit][E] += dp[i][bit - 1][E];
dp[i][bit][D] += dp[i][bit - 1][D];
ll j = i + (1ll << (bit - 1));
dp[i][bit][RED] += dp[i][bit][R] * dp[j][bit - 1][ED] + dp[i][bit][RE] * dp[j][bit - 1][D] + dp[j][bit - 1][RED];
dp[i][bit][RE] += dp[i][bit][R] * dp[j][bit - 1][E] + dp[j][bit - 1][RE];
dp[i][bit][ED] += dp[i][bit][E] * dp[j][bit - 1][D] + dp[j][bit - 1][ED];
dp[i][bit][R] += dp[j][bit - 1][R];
dp[i][bit][E] += dp[j][bit - 1][E];
dp[i][bit][D] += dp[j][bit - 1][D];
}
}
}
for (ll bit = 0; bit < 18; bit++)
{
for (ll i = 1; i <= n; i++)
{
if (i + (1ll << bit) - 1 > n)
break;
if (bit == 0)
{
if (s[i] == 'd')
ndp[i][bit][D]++;
if (s[i] == 'e')
ndp[i][bit][E]++;
if (s[i] == 'r')
ndp[i][bit][R]++;
}
else
{
ndp[i][bit][DER] += ndp[i][bit - 1][DER];
ndp[i][bit][DE] += ndp[i][bit - 1][DE];
ndp[i][bit][ER] += ndp[i][bit - 1][ER];
ndp[i][bit][R] += ndp[i][bit - 1][R];
ndp[i][bit][E] += ndp[i][bit - 1][E];
ndp[i][bit][D] += ndp[i][bit - 1][D];
ll j = i + (1ll << (bit - 1));
ndp[i][bit][DER] += ndp[i][bit][D] * ndp[j][bit - 1][ER] + ndp[i][bit][DE] * ndp[j][bit - 1][R] + ndp[j][bit - 1][DER];
ndp[i][bit][DE] += ndp[i][bit][D] * ndp[j][bit - 1][E] + ndp[j][bit - 1][DE];
ndp[i][bit][ER] += ndp[i][bit][E] * ndp[j][bit - 1][R] + ndp[j][bit - 1][ER];
ndp[i][bit][R] += ndp[j][bit - 1][R];
ndp[i][bit][E] += ndp[j][bit - 1][E];
ndp[i][bit][D] += ndp[j][bit - 1][D];
}
}
}
}
// Returns the count stored in slot x (R, E, D, RE, ED or RED) for the
// substring s[l..r], by chaining power-of-two windows from the dp table.
// An empty range (r < l) yields 0.
ll fd(ll x, ll l, ll r)
{
    if (l > r)
        return 0;
    ll acc[8] = {}; // running counts for s[l .. pos - 1]
    const ll len = r - l + 1;
    ll pos = l;
    // Consume the range left to right, largest power-of-two block first.
    for (ll bit = 17; bit >= 0; --bit)
    {
        if (((len >> bit) & 1) == 0)
            continue;
        const ll *blk = dp[pos][bit];
        // Cross terms must read acc before it absorbs blk's letter counts.
        acc[RED] += acc[R] * blk[ED] + acc[RE] * blk[D] + blk[RED];
        acc[RE] += acc[R] * blk[E] + blk[RE];
        acc[ED] += acc[E] * blk[D] + blk[ED];
        acc[R] += blk[R];
        acc[E] += blk[E];
        acc[D] += blk[D];
        pos += 1ll << bit;
    }
    return acc[x];
}
// Returns the count stored in slot x (R, E, D, DE, ER or DER) for the
// substring s[l..r], by chaining power-of-two windows from the ndp table.
// An empty range (r < l) yields 0.
ll nfd(ll x, ll l, ll r)
{
    if (!(l <= r))
        return 0;
    ll acc[8] = {}; // running counts for s[l .. pos - 1]
    const ll len = r - l + 1;
    ll pos = l;
    ll bit = 17;
    // Walk the set bits of len from high to low; each one is a block starting at pos.
    while (bit >= 0)
    {
        if ((len >> bit) & 1)
        {
            const ll *blk = ndp[pos][bit];
            // Cross terms must read acc before it absorbs blk's letter counts.
            acc[DER] += acc[D] * blk[ER] + acc[DE] * blk[R] + blk[DER];
            acc[DE] += acc[D] * blk[E] + blk[DE];
            acc[ER] += acc[E] * blk[R] + blk[ER];
            acc[R] += blk[R];
            acc[E] += blk[E];
            acc[D] += blk[D];
            pos += 1ll << bit;
        }
        --bit;
    }
    return acc[x];
}
void solve()
{
cin >> n >> k;
cin >> s;
s = " " + s;
init();
// cout << dp[1][1][R] << endl;
// cout << nfd(E, 1, 2);
for (ll i = 1; i <= k; i++)
{
long long ans = 0;
ll x, y;
cin >> x >> y;
ans += fd(RED, 1, x - 1);
ans += nfd(DER, x, y);
ans += fd(RED, y + 1, n);
ans += 1ll * fd(R, 1, x - 1) * fd(E, x, y) * fd(D, y + 1, n);
ans += 1ll * fd(RE, 1, x - 1) * fd(D, x, y);
ans += 1ll * nfd(ER, x, y) * fd(D, y + 1, n);
ans += 1ll * fd(R, 1, x - 1) * nfd(DE, x, y);
ans += 1ll * fd(R, x, y) * fd(ED, y + 1, n);
ans += 1ll * fd(RE, 1, x - 1) * fd(D, y + 1, n);
ans += 1ll * fd(R, 1, x - 1) * fd(ED, y + 1, n);
cout << ans << endl;
}
}
int main()
{
    // Untie iostreams from C stdio for faster bulk I/O.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    // Single test case. NOTE(review): init() accumulates into dp/ndp without
    // clearing them, so reading a test count here (cin >> t) would also
    // require resetting the tables between cases.
    ll t = 1;
    for (; t > 0; --t)
        solve();
    return 0;
}