E题“不平衡的字符串”题解
涉及算法
- 树状数组
- 离散化
题目大意
题目链接https://ac.nowcoder.com/acm/contest/16832/E
给出一个长度为的字符串
,以及
个约束条件。
每个约束条件格式为, 表示要求字符
中字符串中所占的比例
满足
. 保证
,且对每个字符至多有一个约束条件。
求S的所有子串中,至少满足一个约束条件的子串的个数。
例如,给出长度为6的字符串aabaab和两个约束条件。
约束条件1: a 1 2 1 1
约束条件2: b 1 2 1 1
字符串S共有21个子串, 其中只有"ab", "ba", "baab" 不满足任何一个约束条件,其余17个子串至少满足一个约束条件。
题解
首先注意到, 约束条件要求单个字符所占的比例大于
,同一个字符串中不可能有两个字符的比例大于
,所以每个子串至多只能满足一个约束条件,不存在同一个子串同时满足多个约束条件的情况。因此,只要依次计算出满足每个约束条件的子串数目,直接累加就能不重不漏地统计出“至少满足一个约束条件的子串数目”。
现在依次考虑每个约束条件,要求字符在子串中所占的比例在
。用
表示在字符串
的前
个字符中字符
的个数。
考虑由的第i的字符到第j个字符构成的子串,该子串长度为
, 其中字符
的个数为
, 故
的比例
。约束条件可以写成
$\frac{p_j - p_{i - 1}}{j - (i - 1)} \le \frac cd
c(i - 1) - dp_{i - 1} \le cj - dp_j
f_i = ci - dp_i
f_{i - 1} \le f_j$
可以使用权值树状数组来维护满足上式的子串数量,具体操作如下
- 因为
值可能很大,且可能出现负数,需要先将
离散化
- 从
到
依次扫描每个
:
- 向结果累加sum(
)
- 将树状数组中位置
的值加1, add(
, 1)
- 向结果累加sum(
这样得到的结果就是满足的子串个数, 再用相同的方式求出满足
的子串个数,相减即可得到满足约束条件的子串个数。具体实现细节请参考代码。
代码
#include <bits/stdc++.h> using namespace std; const int maxn = 50000 + 100; char s[maxn]; int p[maxn]; int n; struct Tree { inline int lowbit(int x) { return x & -x; } int C[maxn]; void add(int p) { for (; p <= n; p += lowbit(p)) C[p]++; } int sum(int p) { int res = 0; for (; p; p -= lowbit(p)) res += C[p]; return res; } } T1, T2; int val1[maxn], val2[maxn]; vector<int> nums; int main() { scanf("%d", &n); scanf("%s", s + 1); int m; scanf("%d", &m); char ch; int a, b, c, d; long long ans = 0; while (m--) { getchar(); scanf("%c%d%d%d%d", &ch, &a, &b, &c, &d); for (int i = 1; i <= n; i++) p[i] = p[i - 1] + (s[i] == ch); // val1[i] = ai - b * p[i] , 离散化 for (int i = 0; i <= n; i++) { val1[i] = a * i - b * p[i]; nums.push_back(val1[i]); } sort(nums.begin(), nums.end()); nums.erase(unique(nums.begin(), nums.end()), nums.end()); for (int i = 0; i <= n; i++) val1[i] = lower_bound(nums.begin(), nums.end(), val1[i]) - nums.begin() + 1; nums.clear(); // val2[i] = ci - d * p[i] for (int i = 0; i <= n; i++) { val2[i] = c * i - d * p[i]; nums.push_back(val2[i]); } sort(nums.begin(), nums.end()); nums.erase(unique(nums.begin(), nums.end()), nums.end()); for (int i = 0; i <= n; i++) val2[i] = lower_bound(nums.begin(), nums.end(), val2[i]) - nums.begin() + 1; nums.clear(); memset(T1.C, 0, 4 * maxn); memset(T2.C, 0, 4 * maxn); T1.add(val1[0]); T2.add(val2[0]); for (int i = 1; i <= n; i++) { int t1 = T1.sum(val1[i]); int t2 = T2.sum(val2[i]); ans += t2 - t1; T1.add(val1[i]); T2.add(val2[i]); } } printf("%lld\n", ans); return 0; }