后缀数组做法

 后缀数组的做法一般都包括三个数组:sasa:所有后缀中字典序第 ii 大的是从位置 sa[i]sa[i] 开始的后缀; rankrank:位置 ii 开始的后缀在所有后缀中字典序排第 rank[i]rank[i]lcplcp:高度数组,字典序第 ii 大得缀与字典序第 i+1i+1 大的后缀的最长公共前缀(longest common prefix)。另外通过建立 lcplcp 上的 st表,可以求得任意两个后缀 xxyy 的最长公共前缀: query_min(rank[x],rank[y]1),rank[x]<rank[y]query\_min(rank[x],rank[y]-1), 这里假设rank[x]<rank[y]

 大体思路:对于当前字符串 BiB_i 的每个位置 j (1jm)j \ (1 \le j \le m) , 找到从 jj 开始的最长的子串,且这个子串在 AA 中出现过。也就是找到一个最长的长度 lenlen,满足 BiB_i 的子串 [j,j+len1][j,j+len-1] 同时也是 AA 的一个子串;然后再就是找以位置 jj 为左端点,长度在 lenlen 之内所有的区间权值和的最大值,也就是 max(k=jlvk),(jlj+len1)\max( \sum_{k=j}^{l} v_k ), (j \le l \le j+len-1) ,这部分可以对数组 vv 的前缀和数组建 stst 表,然后查询区间 [j,j+len1][j,j+len-1] 的最大值。

关于如何找到 BiB_i 以位置 j (1jm)j \ (1 \le j \le m) 开头,在 AA 中出现过的最长子串:

 把 AA 和所有 BiB_i 连在一起,中间用没出现过的字符(例如 '$')分隔,得到的字符串记为 SS。然后对 SS 跑后缀数组。在拼接的过程中,维护每个 BiB_iSS 中的起始下标,和第 ii 个字符对应于原来哪一个串,后面要用。

 跑出来后缀数组后,按字典序遍历所有后缀。借助前面维护的信息,我们可以知道当前后缀对应哪个 BiB_i 或者说对应 AA; 如果当前后缀是对应某个 BiB_i 的,就找到离它最近的,属于 AA 串的后缀,求它们之间的 lcplcp ,这个 lcplcp 就是我们前面要求的那个 lenlen

比如题目样例一:

(左边的三列数字分别代表:字典序大小,sa 的值,lcp 的值)

alt


 字典序第 1818 大和第 2222 大的后缀都是原属于 AA 的后缀,第 1919 大的后缀对应于 B3B_3 从下标 1 开始的后缀,也就是它本身。 18i<19(lcp[i])=4,19i<22(lcp[i])=0\min_{18\le i<19}(lcp[i]) = 4,\min_{{19 \le i < 22} }(lcp[i]) = 0 (为啥 \usderset这里用不了qaq),那么 B3B_3 从 1 开始,长度在 44 以内的子串都在 AA 中出现过,求得这个最长的公共子串长度后,用 st 表求一个权重前缀和的最大值,更新 B3B_3 的答案即可。

有两个需要注意的点:

1、本来求 lcp 的部分我是用先从前到后和从后到前循环一遍,记录每个位置左边第一个和右边第一个属于 AA 的后缀,然后 st 表查询区间 lcp 最小值,但是超时了。后来发现在循环的时候直接维护最小值就行了,不需要构建 st 表。

2、所有 BiB_i 之间可以用 '$' 分隔,但是 AAB1B_1 之间最好再换个,不然下面这种数据,跑出来的 lcp 可能处理起来有点麻烦 alt

最后勉强 700+ms 跑过:

#include<iostream>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
const int INF = 1e9+7;
const long long INFq = 1e18+7;
const long long mode = 998244353;

const int MAX_N = 1300100;
char s[MAX_N];
///----- SA-IS template -----
int sa[MAX_N], Rank[MAX_N], lcp[MAX_N];
int str[MAX_N<<1], Type[MAX_N<<1], p[MAX_N], cnt[MAX_N], cur[MAX_N];
#define pushS(x) sa[cur[str[x]]--] = x
#define pushL(x) sa[cur[str[x]]++] = x
#define inducedSort(v) fill_n(sa, n, -1); fill_n(cnt, m, 0);                     \
    for (int i = 0; i < n; i++) cnt[str[i]]++;                                   \
    for (int i = 1; i < m; i++) cnt[i] += cnt[i-1];                              \
    for (int i = 0; i < m; i++) cur[i] = cnt[i]-1;                               \
    for (int i = n1-1; ~i; i--) pushS(v[i]);                                     \
    for (int i = 1; i < m; i++) cur[i] = cnt[i-1];                               \
    for (int i = 0; i < n; i++) if (sa[i] > 0 &&  Type[sa[i]-1]) pushL(sa[i]-1); \
    for (int i = 0; i < m; i++) cur[i] = cnt[i]-1;                               \
    for (int i = n-1;  ~i; i--) if (sa[i] > 0 && !Type[sa[i]-1]) pushS(sa[i]-1)
void sais(int n, int m, int *str, int *Type, int *p) {
    int n1 = Type[n-1] = 0, ch = Rank[0] = -1, *s1 = str+n;
    for (int i = n-2; ~i; i--) Type[i] = str[i] == str[i+1] ? Type[i+1] : str[i] > str[i+1];
    for (int i = 1; i < n; i++) Rank[i] = Type[i-1] && !Type[i] ? (p[n1] = i, n1++) : -1;
    inducedSort(p);
    for (int i = 0, x, y; i < n; i++) if (~(x = Rank[sa[i]])) {
        if (ch < 1 || p[x+1] - p[x] != p[y+1] - p[y]) ch++;
        else for (int j = p[x], k = p[y]; j <= p[x+1]; j++, k++)
            if ((str[j]<<1|Type[j]) != (str[k]<<1|Type[k])) {ch++; break;}
        s1[y = x] = ch;
    }
    if (ch+1 < n1) sais(n1, ch+1, s1, Type+n, p+n1);
    else for (int i = 0; i < n1; i++) sa[s1[i]] = i;
    for (int i = 0; i < n1; i++) s1[i] = p[sa[i]];
    inducedSort(s1);
}

int mapCharToInt(int n) {
    int m = *max_element(s, s+n);
    fill_n(Rank, m+1, 0);
    for (int i = 0; i < n; i++) Rank[s[i]] = 1;
    for (int i = 0; i < m; i++) Rank[i+1] += Rank[i];
    for (int i = 0; i < n; i++) str[i] = Rank[s[i]] - 1;
    return Rank[m];
}

void SuffixArray(int n) {
    // s[n] 一定要比 s 中所有字符 ascii 值小, s[n+1] 倒无所谓
    s[n] = '!';  s[n+1]='\0';
    int m = mapCharToInt(++n);
    sais(n, m, str, Type, p);
    for (int i = 0; i < n; i++) Rank[sa[i]] = i;
    for (int i = 0, h = lcp[0] = 0; i < n-1; i++) {
        int j = sa[Rank[i]-1];
        while (i+h < n && j+h < n && s[i+h] == s[j+h]) h++;
        if (lcp[Rank[i]-1] = h) h--;
    }
    s[n]='\0';
}
///----- End of SA-IS -----

long long st2[100010][20];
int lg[100010];
long long pref[100010];
void construct_st2(int n) {
    for(int i=1;i<=n;i++)st2[i][0] = pref[i];

    for(int k=1,len=2; len<=n; len*=2,k++) {
        for(int i=1;i+len-1<=n;i++) {
            st2[i][k] = max( st2[i][k-1], st2[i+len/2][k-1] );
        }
    }
}
inline long long query2(int x,int y) {
    int k = lg[y-x+1];
    return max( st2[x][k], st2[y-(1<<k)+1][k] );
}


int v[100010];
int start_pos[100010];
long long ans[100010];
int Map[MAX_N];

int main() {
    ios::sync_with_stdio(false);
    lg[1] = 0;
    for(int i=2;i<=100000;i++) lg[i] = lg[i >> 1] + 1; 

    int n,m,k;
    cin >> n >> m >> k;
    cin >> s;
    for(int i=1;i<=m;i++) cin >> v[i];

    int tot_len = n-1;
    for(int i=0;i<n;i++) Map[i] = 0; // Map 用来映射 s[i] 对应于原来那个串,0 就是 A;  
    for(int i=1;i<=k;i++) {
        ++ tot_len;
        s[tot_len] = '$'; Map[tot_len] = -1; // -1 代表是分隔符;

        start_pos[i] = tot_len+1; // 记录开始位置
        cin >> ( s + tot_len + 1 );

        for(int j=tot_len+1; j<=tot_len+m; j++) Map[j] = i; // 代表 s[j] 原属于 B_i
        tot_len += m;
    }

    s[n] = '#'; // A 和 B_1 之间用 '#' 而非 '$' 

    ++ tot_len;
    SuffixArray(tot_len ); // 板子传入的参数是字符串的长度,下标从 0 开始, tot_len 是 '\0' 的位置

    // s[tot_len] = '\0';
    // cout << "s = " << s << '\n';
    // for(int i=0;i<=tot_len;i++) {
    //     printf("%3d %3d %3d  %s\n",i,sa[i],lcp[i],s+sa[i]);
    // }
    // cout << '\n';

    // 构建前缀和
    pref[0] = 0;
    for(int i=1;i<=m;i++) pref[i] = pref[i-1] + v[i];
    // 前缀和的区间最大值; 为啥是 st2?因为原本有个(多余的) st 用来求 lcp, 但超时了
    construct_st2(m);

    int Min = 0;
    for(int i=1;i<=tot_len;i++) { // 从左到右遍历一边, 用每个后缀左边第一个属于 A 的后缀更新答案
        int j = Map[ sa[i] ]; // sa[i] 代表字典序第 i 大的后缀在原串的起始位置,再用 Map 映射到原来对应的串
        if( j == 0 ) {
            Min = lcp[i]; // 是 A 的后缀,则重置 Min
        }
        else {
            if( j > 0 && Min > 0 ) { // 对应 B_j 的某个后缀
                int index = sa[i] - start_pos[j] + 1; // index 是对应的 B_j 的那个后缀的起始下标
                long long Max = query2( index, index + Min - 1 ); // 查询区间 pref 最大值
                ans[j] = max( ans[j] , Max - pref[index-1] ); // 更新答案
            }
            Min = min( Min, lcp[i] );
        }
    }

    Min = 0;
    for(int i=tot_len;i>0;i--) { // 从右到左遍历,用每个后缀右边第一个属于 A 的后缀更新答案,几乎一样的
        int j = Map[ sa[i] ];
        if( j == 0 ) {
            Min = lcp[i-1];
        }
        else {
            if( j > 0 && Min > 0 ) {
                int index = sa[i] - start_pos[j] + 1;
                long long Max = query2( index, index + Min - 1 );
                ans[j] = max( ans[j] , Max - pref[index-1] );
            }
            Min = min( Min, lcp[i-1] );
        }
    }
    for(int i=1;i<=k;i++) cout << ans[i] << '\n';
}

后缀自动机做法

和后缀数组的思路是一样的,不过这里对于字符串 BiB_i 的每个位置 j (1jm)j \ (1 \le j \le m) ,是找以 jj 结尾的最长的在 AA 中出现过的子串,后面查询的也是 [jlen, j1][j-len,\ j-1] 之间前缀和的最小值。

怎么找:

AA 建立后缀自动机后,记当前的 BiB_iTT (这样我能少打一个下标qwq),在 AA 的自动机上跑匹配,假设 TT 串第 i1i-1 的位置在 AA 的自动机上匹配的最大子串长度为 max_lenmax\_len,对应自动机上的节点为 last_poslast\_pos,那么以 T[i]T[i] 结尾的串肯定是某个以 T[i1]T[i-1] 结尾的串后面加上字符 T[i]T[i],我们就从 last_poslast\_pos 开始在 parentparent 树中向上转移,直到遇到第一个存在字符 T[i]T[i] 的出边的节点位置,这个过程中记录 max_lenmax\_len ,最后 +1 就是 ii 的答案。

嗯,自己写的自己都看不懂写的什么东西。 还是看图吧

假设 AAbcdabcbcdabcTTabcdabcd,开始 last_poslast\_pos 设为 11 ,代表根节点, max_len=0max\_len = 0,因为根节点对应的子串为空串,AA 的自动机长这样子:(每个节点块最后一行{}里的是该节点的 endposendpos 集——节点代表的子串在原串中的结束位置;黑色的边是 parent 的边,蓝色带箭头的是自动机的转移边,旁边的字母是对应的出边的类型;黑边上也有字母是因为 parent 的边和自动机的边重了;len 代表当前节点所代表的子串的最大长度)

alt

每个节点对应的子串:

alt

首先是 a,正好 11 号节点有 a 的出边,走到 5 号节点,max_lenmax\_len++,最大长度为 11

然后是 b55 号节点有 b 的出边,走到 66 号节点,max_lenmax\_len++,最大长度为 22

然后是 c66 号节点有 c 的出边,走到 77 号节点,max_lenmax\_len++,最大长度为 33

然后是 d77 号节点没有 d 的出边,沿 parentparent 数向上走,走到 33 号节点有 d 的出边,走到 44 号节点,max_len=len[3]+1=3max\_len = len[3]+1 = 3

#include<iostream>
#include<cstdio>
#include<vector> 
#include<cstring> 
using namespace std;

const int MAX_N = 100010;

int par[MAX_N<<1], sam[MAX_N<<1][26],len[MAX_N<<1];
int last,tot;

void sam_extend(int ch) {
    int p = last;
    tot++;
    int np = last = tot;
    len[np] = len[p] + 1;

    while( p>0 && sam[p][ch]==0 ){
        sam[p][ch] = np;
        p = par[p];
    }

    if( p==0 ){
        par[np] = 1;
    }
    else{
        int q = sam[p][ch];
        if( len[q] == len[p]+1 )par[np] = q;
        else{
            tot++;
            int nq = tot;
            len[nq] = len[p]+1;
            par[nq] = par[q];
            for(int i=0;i<26;i++)sam[nq][i] = sam[q][i];
            par[np] = par[q] = nq;

            while( p>0 && sam[p][ch]==q ){
                sam[p][ch] = nq;
                p = par[p];
            }
        }
    }
}

int last_pos, max_len;
void Go(int ch) {
    int p = last_pos;

    while( p > 0 && sam[p][ch] == 0 ) {
        p = par[p];
        max_len = len[p];
    }

    if( p == 0 ) {
        // 如果 1 号根节点都没有 ch 的出边,说明字符 ch 在字符串中不存在
        last_pos = 1;
    }
    else {
        int q = sam[p][ch]; // 沿着出边走出去
        ++ max_len; // 就是当前的最大长度
        last_pos = q;
    }
}

long long pref[100010];
long long st[MAX_N][20];
int lg[MAX_N];
void construct_st(int n) {
    for(int i=0;i<=n;i++)st[i][0] = pref[i];

    for(int k=1,len=2; len<=n; len*=2,k++) {
        for(int i=0;i+len-1<=n;i++) {
            st[i][k] = min( st[i][k-1], st[i+len/2][k-1] );
        }
    }
}

long long query(int left,int right) {
    int k = lg[right-left+1];
    return min( st[left][k], st[right-(1<<k)+1][k] );
}

char s[100010], t[100010];
int v[100010];
long long ans[100010];

void print_sam(){
    vector<int>edge[20];
    for(int i=2;i<=tot;i++)edge[par[i]].push_back(i);
    for(int i=1;i<=tot;i++) {
        printf("child %d :",i); for(int u : edge[i])printf(" %d",u); printf("\n");
    }

    for(int i=1;i<=tot;i++) {
        printf("sam %d :\n",i);
        for(int j=0;j<26;j++) {
            if( sam[i][j] > 0 ) {
                printf("  %c -> %d\n",'a'+j,sam[i][j]);
            }
        }
    }
}

int main() {
    // cin.tie(nullptr) -> sync_with_stdio(false);
    lg[1] = 0;
    for(int i=2;i<=100000;i++) lg[i] = lg[i >> 1] + 1; 

    int n,m,k;
    cin >> n >> m >> k;
    cin >> (s+1);
    for(int i=1;i<=m;i++) cin >> v[i];

    last = tot = 1;
    for(int i=1;i<=n;i++) sam_extend(s[i] - 'a');


    pref[0] = 0;
    for(int i=1;i<=m;i++) pref[i] = pref[i-1] + v[i];
    construct_st(m);

    for(int j=1; j<=k; j++) {
        cin >> (t+1);
        last_pos = 1;
        max_len = 0;
        for(int i=1;i<=m;i++) {
            Go(t[i] - 'a'); // 在 parent 树上沿着 last_pos 向上找到第一个有出边 t[i] 的节点
            
            if( max_len > 0 )
                ans[j] = max( ans[j], pref[i] - query(i-max_len,i) ); 
            
        }
    }

    for(int i=1;i<=k;i++) cout << ans[i] << '\n';
}

最后闲扯些给不了解后缀自动机:

 沿着 parentparent 树向下走,相当于在左边添加字符,而越长的子串在原串中的的出现位置相对更少。为什么说 “从 last_poslast\_pos 开始在 parentparent 树中向上转移,直到遇到第一个存在字符 T[i]T[i] 的出边的节点位置”,因为向上走,相当于不断去掉左边的字符,越短的子串在原串的出现的位置相对更多,更“可能”会遇到一个后面跟着一个字符 T[i]T[i] 的位置。

 而沿着 samsam 的出边转移,相当于在子串的后面添加字符。