后缀数组、单调栈

题意:

图片说明

分析:

题意十分的清爽,但是却让人一筹莫
仔细地分析我们便会发现,我们可以利用后缀数组来进行求解。我们比对s1和s2的每一个后缀。然后计算其和就可以了。
对不对,其实题目中的也就是这个意思而已罢了。
很简单,我们可以这样构造一个字符串s1+'$'+s2,我们求其SA和rmq。然后进行枚举即可。
但是致命的就是没几句这一步。我们必须枚举s1.size()*s2.size()次。百分百tle

我们必须要寻找其它的办法代替枚举。
我们从求解两后缀最长公共前缀的过程开始。
为 min(height[i,...j])
也就是一个区间内的最小值。
那么假如我们要求字符串s的所有的后缀的相互公共前缀之和其实就是求所有区间内最小值之和。
但是,题目中球的不是所有后缀的最长公共前缀之和。而是s1的后缀和s2的后缀的最长公共前缀之和。
很简单,我们求一下s1+'$'+s2的所有后缀的最长充公前缀之和。再减s1的所有后缀的最长公共前缀之和与s2的所有后缀的最长公共前缀之和就可以了。 稍微有点绕哈。

那么我们考虑求一个字符串的所有后缀相互之间的最长公共前缀之和,即求这个字符串对应的height数组所有区间的最小值之和
如何求呢,所有区间的最小值之和?(height区间从索引1开始,即不包括0。因为根据height的性质height[0]是没有意义的)

这里便要用到名为单调栈的数据结构了。
这是我第一次接触单调栈。
何为单调栈?顾名思义,我们遍历height的索引从头到尾。
维护一个单调递增或者单调递减的序列,这里的单调指的是height,我们向单调栈中push的是索引

这里我们维护一个单调递增栈
遍历到索引i,如果height[stack.top()]<=height[i]那么我们直接push(i)
如果height[stack.top()]>height[i]那么我们就pop()
直到height[stack.top()]<=height[i]后者栈空了。我们再push(i)停下来。

我们要在这个过程中求出所有区间的最小值之和(包括[i,i]的形式,仔细回想如何求最长公共前缀)。
先说说想法,我们push(i),就计算出以i为右端点所有的区间的最小值之和。
我们想想单调递增栈,在push的的过程中无非两种情况;
1.
图片说明
正好大于height[stack.top()]
2.
图片说明
小于height[]stack.top()]我们pop()
变成图片说明

我们会发现,stack中的两个元素间stack[i-1]与stack[i]间要不没有元素,要不有的元素都大于stack[i]对不.
这意味着ehight[stack[i]]就是最小值。

假设我们知道了以stack.top()为右端点的所有区间的最小值之和sum。那么我们要求以i为右端点的所有最小值之和时。
若height[i]>=height[stack.top()]
因为上面我们分析出来的性质,我们只要关心左端点在stack.top()与i的情况再加上sum就可以了。
因为如果我们关系左端点在stack.top()左边的情况的话,我们一定要经过stack.top()

所以我们就sum+=(i-stack.top())*height[i]就好了
如果height[i]<height[stack.top()]
那么根据上面画的图。我们只要不断pop并减去相应的其贡献就可以了。

代码如下:

#include<iostream>
#include<algorithm>
using namespace std;
#define re register
typedef long long ll;
const int max_n = 4e5 + 100;
int ranks[max_n], SA[max_n], height[max_n];
int wa[max_n], wb[max_n], wvarr[max_n], wsarr[max_n];
inline int cmp(int* r, int a, int b, int l) {
    return r[a] == r[b] && r[a + l] == r[b + l];
}
inline void get_sa(int* r, int* sa, int n, int m) {
    ++n;
    re int i, j, p, * x = wa, * y = wb, * t;
    for (i = 0; i < m; ++i) wsarr[i] = 0;
    for (i = 0; i < n; ++i) wsarr[x[i] = r[i]]++;
    for (i = 1; i < m; ++i) wsarr[i] += wsarr[i - 1];
    for (i = n - 1; i >= 0; --i) sa[--wsarr[x[i]]] = i;
    for (j = 1, p = 1; p < n; j <<= 1, m = p) {
        for (p = 0, i = n - j; i < n; ++i) y[p++] = i;
        for (i = 0; i < n; ++i) if (sa[i] >= j) y[p++] = sa[i] - j;
        for (i = 0; i < n; ++i) wvarr[i] = x[y[i]];
        for (i = 0; i < m; ++i) wsarr[i] = 0;
        for (i = 0; i < n; ++i) wsarr[wvarr[i]]++;
        for (i = 1; i < m; ++i) wsarr[i] += wsarr[i - 1];
        for (i = n - 1; i >= 0; --i) sa[--wsarr[wvarr[i]]] = y[i];
        for (t = x, x = y, y = t, p = 1, x[sa[0]] = 0, i = 1; i < n; ++i)
            x[sa[i]] = cmp(y, sa[i - 1], sa[i], j) ? p - 1 : p++;
    }
}
void get_height(int* r, int* sa, int n) {
    re int i, j, k = 0;
    for (i = 1; i <= n; ++i) ranks[sa[i]] = i;
    for (i = 0; i < n; height[ranks[i++]] = k)
        for (k ? k-- : 0, j = sa[ranks[i] - 1]; r[i + k] == r[j + k]; k++);
    return;
}

int a[max_n];
void init(int n) {
    fill(a, a + n + 3, 0);
    fill(ranks, ranks + n + 3, 0);
    fill(SA, SA + n + 3, 0);
    fill(height, height + n + 3, 0);
    fill(wa, wa + n + 3, 0);
    fill(wb, wb + n + 3, 0);
    fill(wsarr, wsarr + n + 3, 0);
    fill(wvarr, wvarr + n + 3, 0);
}
struct monostack {
    int tax[max_n];
    int pos = -1;
    inline void push(int u) {
        tax[++pos] = u;
    }inline int top() {
        return tax[pos];
    }inline void pop() {
        if (!empty())--pos;
    }inline bool empty() {
        return pos == -1;
    }inline void clear() {
        pos = -1;
    }inline int size() {
        return pos + 1;
    }inline int operator[](int i) {
        return tax[i];
    }
}stac;
ll calcu(string& s) {
    init(s.size());
    for (int i = 0;i < s.size();++i)a[i] = (s[i] - 'a' + 1);
    get_sa(a, SA, s.size(), 30);
    get_height(a, SA, s.size());
    stac.clear();
    ll ans = 0;
    ll sum = 0;
    stac.push(0);height[0] = 0;
    for (re int i = 1;i <= s.size();++i) {
        while (stac.size() >= 2 && height[stac.top()] > height[i]) {
            sum -= (ll)height[stac.top()] * ((ll)stac[stac.size() - 1] - stac[stac.size() - 2]);
            stac.pop();
        }stac.push(i);
        sum += (ll)height[stac.top()] * ((ll)stac[stac.size() - 1] - stac[stac.size() - 2]);
        ans += sum;
    }return ans;
}

int main() {
    ios::sync_with_stdio(0);
    string s1, s2;cin >> s1 >> s2;
    string s3 = s1 + (char)('z' + 1) + s2;
    cout << calcu(s3) - calcu(s1) - calcu(s2) << endl;
}