E题“不平衡的字符串”题解

涉及算法

  • 树状数组
  • 离散化

题目大意

题目链接https://ac.nowcoder.com/acm/contest/16832/E
给出一个长度为的字符串,以及个约束条件。
每个约束条件格式为, 表示要求字符中字符串中所占的比例满足. 保证,且对每个字符至多有一个约束条件。
求S的所有子串中,至少满足一个约束条件的子串的个数。
例如,给出长度为6的字符串aabaab和两个约束条件。
约束条件1: a 1 2 1 1
约束条件2: b 1 2 1 1
字符串S共有21个子串, 其中只有"ab", "ba", "baab" 不满足任何一个约束条件,其余17个子串至少满足一个约束条件。

题解

首先注意到, 约束条件要求单个字符所占的比例大于,同一个字符串中不可能有两个字符的比例大于,所以每个子串至多只能满足一个约束条件,不存在同一个子串同时满足多个约束条件的情况。因此,只要依次计算出满足每个约束条件的子串数目,直接累加就能不重不漏地统计出“至少满足一个约束条件的子串数目”。

现在依次考虑每个约束条件,要求字符在子串中所占的比例在。用表示在字符串的前个字符中字符的个数。
考虑由的第i的字符到第j个字符构成的子串,该子串长度为, 其中字符的个数为, 故的比例。约束条件可以写成
$\frac{p_j - p_{i - 1}}{j - (i - 1)} \le \frac cdc(i - 1) - dp_{i - 1} \le cj - dp_jf_i = ci - dp_if_{i - 1} \le f_j$
可以使用权值树状数组来维护满足上式的子串数量,具体操作如下

  1. 因为值可能很大,且可能出现负数,需要先将离散化
  2. 依次扫描每个:
    • 向结果累加sum()
    • 将树状数组中位置的值加1, add(, 1)

这样得到的结果就是满足的子串个数, 再用相同的方式求出满足的子串个数,相减即可得到满足约束条件的子串个数。具体实现细节请参考代码。

代码

#include <bits/stdc++.h>

using namespace std;

const int maxn = 50000 + 100;

char s[maxn];
int p[maxn];

int n;

struct Tree {
    inline int lowbit(int x) {
        return x & -x;
    }
    int C[maxn];

    void add(int p) {
        for (; p <= n; p += lowbit(p))
            C[p]++;
    }

    int sum(int p) {
        int res = 0;
        for (; p; p -= lowbit(p))
            res += C[p];
        return res;
    }
} T1, T2;

int val1[maxn], val2[maxn];
vector<int> nums;

int main() {

    scanf("%d", &n);
    scanf("%s", s + 1);

    int m;
    scanf("%d", &m);
    char ch;
    int a, b, c, d;

    long long ans = 0;
    while (m--) {
        getchar();
        scanf("%c%d%d%d%d", &ch, &a, &b, &c, &d);

        for (int i = 1; i <= n; i++)
            p[i] = p[i - 1] + (s[i] == ch);

        // val1[i] = ai - b * p[i] , 离散化
        for (int i = 0; i <= n; i++) {
            val1[i] = a * i - b * p[i];
            nums.push_back(val1[i]);
        }
        sort(nums.begin(), nums.end());
        nums.erase(unique(nums.begin(), nums.end()), nums.end());
        for (int i = 0; i <= n; i++)
            val1[i] = lower_bound(nums.begin(), nums.end(), val1[i]) - nums.begin() + 1;
        nums.clear();

        // val2[i] = ci - d * p[i]
        for (int i = 0; i <= n; i++) {
            val2[i] = c * i - d * p[i];
            nums.push_back(val2[i]);
        }

        sort(nums.begin(), nums.end());
        nums.erase(unique(nums.begin(), nums.end()), nums.end());
        for (int i = 0; i <= n; i++)
            val2[i] = lower_bound(nums.begin(), nums.end(), val2[i]) - nums.begin() + 1;
        nums.clear();

        memset(T1.C, 0, 4 * maxn);
        memset(T2.C, 0, 4 * maxn);
        T1.add(val1[0]);
        T2.add(val2[0]);
        for (int i = 1; i <= n; i++) {
            int t1 = T1.sum(val1[i]);
            int t2 = T2.sum(val2[i]);
            ans += t2 - t1;
            T1.add(val1[i]);
            T2.add(val2[i]);
        }
    }

    printf("%lld\n", ans);

    return 0;
}