I 题题意:定义一个串的最小表示为——按出现某个字符第一次出现顺序从小到大的排序,并依次编号 abczabc\cdots z。给定串 SS,对其全部后缀进行最小表示的排序。S2×105|S| \leq 2\times 10^5

解法:考虑朴素的后缀数组排序,即是对后缀 i,ji,j 的最小表示排序。如果对于每对比较可以在 logn\log n 的复杂度内可以完成,那么总的复杂度就为 O(n2n)\mathcal O(n \log^2n)。对于 logn\log n 的字符串比较,通常可以转化为求 lcp\rm lcp 长度,再比较接下来的第一位的大小。

首先可以记忆每个字符所在的位置,以 O(Σnlogn)\mathcal O(|\Sigma|n \log n) 的复杂度内计算出每个位置的最小表示中的字符映射关系。考虑如何求 lcp\rm lcp 的长度,朴素的实现是依次枚举位置 i,ji,j 的最小表示中每一位的字符,看它们从第 i,ji,j 位开始只考虑当前字符的最长匹配长度。即:

chi,k:100001001001后缀ichj,k:010001001010后缀jch_{i,k}:1000\underbrace{01001001\cdots}_{\text{后缀}_i}\\ ch_{j,k}:0100\underbrace{01001010\cdots}_{\text{后缀}_j}\\

由上图,即是比较后缀 i,ji,j 的最小表示中的第 kk 个字符,然后仅考虑对应字符 chi,kch_{i,k}chj,kch_{j,k} 的出现位置,其余字符忽略,单纯考虑它们的 lcp\rm lcp。例如在上图中,这两个后缀在第 kk 个字符(可能 chi,kchj,kch_{i,k} \neq ch_{j,k})的 lcp\rm lcp 就是 66。那么这个可以通过记录每一种字符出现位置的哈希值来做到 O(logn)\mathcal O(\log n) 的计算 lcp\rm lcp。然后取每一种字符的 lcp\rm lcp 的最小值为这一对的 lcp\rm lcp。这样的朴素实现是 O(Σlogn)\mathcal O(|\Sigma|\log n),不符合要求。

一个有力的优化是,每次二分的长度越来越短。即第二次二分的长度和是否进行二分在第一次的基础之上进行。具体来说,进行三个优化:

  1. chi,kch_{i,k}chj,kch_{j,k} 的第一个出现字符已经超出当前二分的最大值 rr,则可以终止字符枚举。这是由于字符映射是根据字符出现的第一个位置递增顺序排序,若第 kk 个字符出现位置都超过 rr,则更后面的字符出现位置只会更靠后,而对于它们的比较中,后缀 i,ji,j 在第 kk 个字符的串的 [i,i+r1][i,i+r-1][j,j+r1][j,j+r-1] 范围里全是 00(未出现),则必然相同,没必要继续比较了。
  2. 此外,首先比较 [i,i+r1][i,i+r-1][j,j+r1][j,j+r-1] 的哈希值,若相同则当前位置没必要进行二分。
  3. 根据上一位的 rr 来确定本次的二分上界 rr

因而这样可以将复杂度降到 O(max(Σ),logn)\mathcal O(\max(|\Sigma|),\log n)。因而利用 sort 的优秀常数,足以以 O(nlognmax(logn,Σ))\mathcal O(n \log n \max(\log n,|\Sigma|)) 通过本题。

#include <bits/stdc++.h>
using namespace std;
const int N = 200000;
const long long base = 2, mod = 1000000007;
long long power(long long a, long long x)
{
    long long ans = 1;
    while(x)
    {
        if (x & 1)
            ans = ans * a % mod;
        a = a * a % mod;
        x >>= 1;
    }
    return ans;
}
long long inv(long long a)
{
    return power(a, mod - 2);
}
long long th[N + 5], invth[N + 5];
long long h[26][N + 5];
char s[N + 5];
long long get_hash(int id, int l, int r)
{
    return (h[id][r] - h[id][l - 1] + mod) % mod * invth[l - 1] % mod;
}
vector<int> pos[26];
int chrk[N + 5][26], n;
bool cmp(int a, int b)
{
    int *t1 = chrk[a], *t2 = chrk[b];
    int left = 2, right = n - max(a, b) + 1;
    for (int i = 0; i < 26 && right > 1;i++)
    {
        int l = left, r = right;
        if (t1[i] - a >= right && t2[i] - b >= right)
            break;
        if (t1[i] - a != t2[i] - b)
        {
            right = min(t1[i] - a, t2[i] - b);
            break;
        }
        int ida = s[t1[i]] - 97, idb = s[t2[i]] - 97;
        if (get_hash(ida, a, a + right - 1) == get_hash(idb, b, b + right - 1))
            continue;
        while (l <= r)
        {
            int mid = (l + r) >> 1;
            if (get_hash(ida, a, a + mid - 1) == get_hash(idb, b, b + mid - 1))
            {
                right = mid;
                l = mid + 1;
            }
            else
                r = mid - 1;
        }
    }
    if (max(a, b) + right > n)
        return a > b;
    vector<int> rk1(26, 0), rk2(26, 0);
    for (int i = 0; i < 26; i++)
    {
        if (t1[i] <= n)
            rk1[s[t1[i]] - 97] = i;
        if (t2[i] <= n)
            rk2[s[t2[i]] - 97] = i;
    }
    return rk1[s[a + right] - 97] < rk2[s[b + right] - 97];
}
int ans[N + 5];
int main()
{
    th[0] = invth[0] = 1;
    for (int i = 1; i <= N;i++)
        th[i] = th[i - 1] * base % mod;
    invth[N] = inv(th[N]);
    for (int i = N - 1; i >= 1;i--)
        invth[i] = invth[i + 1] * base % mod;
    scanf("%d%s", &n, s + 1);
    for (int i = 1; i <= n;i++)
    {
        ans[i] = i;
        pos[s[i] - 97].push_back(i);
        for (int j = 0; j < 26;j++)
            h[j][i] = h[j][i - 1];
        h[s[i] - 97][i] = (h[s[i] - 97][i] + th[i]) % mod;
    }
    for (int i = 0; i < 26;i++)
        pos[i].push_back(n + 1);
    for (int i = 1; i <= n;i++)
    {
        vector<int> temp(26);
        for (int j = 0; j < 26;j++)
            temp[j] = *lower_bound(pos[j].begin(), pos[j].end(), i);
        sort(temp.begin(), temp.end());
        for (int j = 0; j < 26;j++)
            chrk[i][j] = temp[j];
    }
    sort(ans + 1, ans + n + 1, cmp);
    for (int i = 1; i <= n;i++)
        printf("%d ", ans[i]);
    return 0;
}