血淋淋的教训!!!!!!!

这一题我觉得出的很好。后缀数组+单调栈。

我们很容易想到两个字符串通过间隔符相连
height分组筛选出来一堆互相最长公共前缀大于等于K的后缀
但是如何统计他们之间的贡献却成为了一个难事情。
这里我们使用的技巧叫做 单调栈!!!

其实好久之前我就遇到单调栈了,但是当时没有钻研。。。。。。

分组之后,我们统计每一个区间中属于s2的后缀与他前面属于s1的后缀所造成的贡献
这样再扫一遍统计属于s1的后缀与他前面属于s2的后缀所造成的贡献。那就是答案了。

这两个问题是对称的,我们来看第一个问题。
假设我们对于s2的后缀i要统计他前面的s1的后缀对他的贡献

这是一个怎样的过程?
我们从i往前找,一路统计最小值,遇到s1的后缀就 ans+=(min-K+1)
对不对。
我们会发现,一个最小值可以对应多个s1的后缀。
也就是说,直到遇到下一个比我小的,我min始终不变,在此路上,遇到的s!的后缀的贡献全部都是min-K+1

那么我们可以使用单调栈,从上往下扫。
我们向栈中压一个二元组(min,cnt)
指 min(height)为min的s1的后缀有cnt个

每当我们扫描到下一个时看一下他的heigh[j]
然后从栈顶开始看,如果栈顶元素的min<height[j]记录下他的cnt然后把他pop掉
然后减去这个区间队后面s2的后缀做的贡献(min-K+1)cnt
就这样直到,栈顶元素的min>height[j]然后在压入height和我们一路记录下来的cnt之和
并加上这个区间中的s1的后缀队后面的s2的后缀做的贡献(height[j]-K+!)
cnt
这其实就是,到j了,之前的元素如果其min比height[j]自然就应该被替换成height[j]
很是巧妙,我真的佩服发明出单调栈和单调队列的人们!!!!!!!!!
我们会发现一进一出时间O(n)
真牛!!!!!!
真的是应该好好理解。
我们之所以维护一个单调递增栈而不是一个单调递减栈,其原因就是:
在从上往下遍历的过程中。后面出现的height[]小于之前的height[]不会影响
但是如果小于的话就会影响。反应在上述操作中的就是我们需要一路pop栈顶,直到栈顶元素的height[]小于当前height[]
即前面的不在受影响!!!!
真牛!!!!!!!!!!!简单但是巧妙

而我在哪里被坑了呢?
我是在求后缀数组时被坑的!!!
我是无论是字符还是数字都统统转化为数字求的后缀数组。
在多组数据时,竟然在求解height数组时会因为前面的数据而导致当前数据求解错误!!!!!
原因是在比对r过程中,前面的数据残留会导致错误。
在我后缀数组的模板中r中似不可以出现0的,所以为了避免前面数据残留在承德影响,我们需要a[n]=0
坑死人了!!!!!!WA了一天!!!呜呜呜呜呜。。。。。。还好不是正式比赛!!!!!:)

代码如下

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<stack>
using namespace std;
typedef long long ll;
typedef pair<ll, ll> pll;
const int max_n = 2e5 + 100;
int a[max_n];
int ranks[max_n], SA[max_n], height[max_n];
int wa[max_n], wb[max_n], wvarr[max_n], wsarr[max_n];
inline int cmp(int* r, int a, int b, int l) {
    return r[a] == r[b] && r[a + l] == r[b + l];
}
inline void get_sa(int* r, int* sa, int n, int m) {
    ++n;
    int i, j, p, * x = wa, * y = wb, * t;
    for (i = 0; i < m; ++i) wsarr[i] = 0;
    for (i = 0; i < n; ++i) wsarr[x[i] = r[i]]++;
    for (i = 1; i < m; ++i) wsarr[i] += wsarr[i - 1];
    for (i = n - 1; i >= 0; --i) sa[--wsarr[x[i]]] = i;
    for (j = 1, p = 1; p < n; j <<= 1, m = p) {
        for (p = 0, i = n - j; i < n; ++i) y[p++] = i;
        for (i = 0; i < n; ++i) if (sa[i] >= j) y[p++] = sa[i] - j;
        for (i = 0; i < n; ++i) wvarr[i] = x[y[i]];
        for (i = 0; i < m; ++i) wsarr[i] = 0;
        for (i = 0; i < n; ++i) wsarr[wvarr[i]]++;
        for (i = 1; i < m; ++i) wsarr[i] += wsarr[i - 1];
        for (i = n - 1; i >= 0; --i) sa[--wsarr[wvarr[i]]] = y[i];
        for (t = x, x = y, y = t, p = 1, x[sa[0]] = 0, i = 1; i < n; ++i)
            x[sa[i]] = cmp(y, sa[i - 1], sa[i], j) ? p - 1 : p++;
    }
}
void get_height(int* r, int* sa, int n) {
    int i, j, k = 0;
    for (i = 0; i <= n; ++i) ranks[sa[i]] = i;
    for (i = 0; i < n; height[ranks[i++]] = k)
        for (k ? k-- : 0, j = sa[ranks[i] - 1]; r[i + k] == r[j + k]; k++);
    return;
}
int K;
char s1[max_n], s2[max_n];
int main() {
    while (scanf("%d", &K)) {
        if (K == 0)break;
        scanf("%s\n%s", &s1, &s2);
        int n1 = strlen(s1);
        int n2 = strlen(s2);
        for (int i = 0;i < n1;++i)
            a[i] = s1[i];
        a[n1] = 290;
        for (int i = n1 + 1;i < n1 + n2 + 1;++i)
            a[i] = s2[i - n1 - 1];
        a[n1 + n2 + 1] = 0;
        get_sa(a, SA, n1 + n2 + 1, 300);
        get_height(a, SA, n1 + n2 + 1);
        ll ans = 0;ll sum = 0;
        stack<pll> st;
        for (int i = 2;i <= n1 + n2 + 1;++i) {
            ll cnt = 0;
            while (!st.empty() && st.top().first >= height[i]) {
                cnt += st.top().second;
                sum -= (st.top().first - K + 1) * st.top().second;
                st.pop();
            }if (height[i] < K) {
                sum = 0;
                continue;
            }cnt += (SA[i - 1] < n1);
            if (cnt != 0)st.push(pll(height[i], cnt));
            sum += ((ll)height[i] - K + 1) * cnt;
            if (SA[i] > n1)ans += sum;
        }sum = 0;
        while (!st.empty())st.pop();
        for (int i = 2;i <= n1 + n2 + 1;++i) {
            ll cnt = 0;
            while (!st.empty() && st.top().first >= height[i]) {
                cnt += st.top().second;
                sum -= (st.top().first - K + 1) * st.top().second;
                st.pop();
            }if (height[i] < K) {
                sum = 0;
                continue;
            }cnt += (SA[i - 1] > n1);
            if (cnt != 0)st.push(pll(height[i], cnt));
            sum += ((ll)height[i] - K + 1) * cnt;
            if (SA[i] < n1)ans += sum;
        }printf("%lld\n", ans);
    }
}