2025 牛客多校 Round 5 B Extra Training

大概题意

给定一个字符串 , 每次可以选择 的一个子串(可以是空串)拼接, 问 次拼接可以形成多少种不同的串。
范围:

分析

显然不能够直接拼接 次,得考虑使用矩阵快速幂加速递推。
我们需要对拼接出来的字符串去重,所以我们可以选择一个字符串最早被拼出来的那一次进行计算。 如果我们要计算一个字符串被最早拼出来的一次,我们必须要对选择来拼接的子串做出限制以达到我们的目标。
一个好理解的结论 : 如果一个子串 加上一个字符 后仍然属于原串的子串,那么显然这不会是一个 "最快的" 方案。
那么我们如果做出如下限制: 对于我们在串后拼接的子串,只能选择符合"最快"方案的,即如果我们当前拼接出来的串中,最后一个子串为 , 且定义 的接受字符集 , 那么我们只能选择以 中字符开头的子串进行拼接。
再考虑设置一个结束状态 ,任何字符开头的子串可以后接 结束状态
这个结束状态可以理解为空串,开始拼接空串即代表构造结束。
于是我们拼接一个子串 只受到开头字符的限制,并在拼接后变成 的限制。
在后缀自动机上,我们可以在DAG上得到每个endpos集合的第一个字符的集合,也可以得到它们后接哪些字符可以转移,哪些不能。
于是我们设计如下的递推状态转移:
表示当前拼接出来的串中,以 字符开头, 且可以后接 字符的串的个数。其中也包含结束状态字符。
表示 中以 字符开头,且可以后接 字符的子串的个数。
滚动递推,有:

完全就是矩阵乘法的形式!
于是我们可以将这个转移写成矩阵递推的形式,使用矩阵快速幂优化。
初始的dp状态是一个初始字符 ,可以后接任意字符开头的串。所以是

下面我们想办法利用后缀自动机求出转移矩阵
特殊的,我们在结束状态只能拼接一个空串,此时结束状态仍然保持,所以有
一般的,我们要求出所有以 字符开头,且可以后接 字符的子串的个数。
如果直接先序遍历SAM的DAG,在转移的时候记录每个节点以每个字符开头的子串数,计数是 的,因为后缀自动机的转移边只有 条。
然而,后续统计是麻烦的。对于每个不能转移的边,我们得花费 的时间将贡献加到 上,这是 的。而这个复杂度的代码将会超时。(虽然哥哥的代码卡过去了)
可能的代码如下:

    vector<array<int,V>>cnt(m);
    //cnt[i][j] ->node is i, first char is j  cnt
    vector<int>indeg(m);//tmp for topo sort
    for(int i = 1;i < m;i++){
        for(int j = 0;j < M;j++){
            if(t[i].nxt[j] != 0){
                indeg[t[i].nxt[j]]++;
            }
        }
    }
    // init first char from root, when tranverse, first is keep 
    for(int i = 0;i < M;i++){
        if(t[1].nxt[i] != 0)
            cnt[t[1].nxt[i]][i]++;
    }
    queue<int>q;
    q.push(1);
    while(!q.empty()){//topo order tranverse
        int p = q.front();
        q.pop();
        for(int i = 0;i < M;i++){
            
            int s = t[p].nxt[i];//first cases
            if(s){
                cnt[s] += cnt[p];
                indeg[s]--;
                if(indeg[s] == 0)q.push(s);
            }
            else{
                for(int j = 0;j < M;j++){//O(nV^2)
                    tr[j][i] += cnt[p][j];
                }
            }
        }
        for(int j = 0;j < M;j++){
            tr[j][V - 1] += cnt[p][j];
        }
    }
    print(tr);

这是因为我们预处理以每个字符开头的串后,将其均摊到每个节点上去计算了。也就是我们是通过DAG转移开头字符信息在节点上计算可否后接字符信息
但是如果我们在DAG上转移可否后接字符信息,在节点上计算开头字符信息,情况是否会变得不一样呢?
使用这种做法,我们使用后序遍历来遍历DAG,计算是否可以转移时,如果不能,我们加到该节点的转移计数数组上,而这大概有 次。 然后,我们通过转移边,将其计数信息 转移到父节点上(当然可能有多个,这代表前接一个字符),而转移边有条,这部分的复杂度也是 的。
最后,我们只需要通过根节点访问以每个字符开头的"根节点",遍历可以后接的字符计数数组,加入到 中,这部分时间复杂度是 。 在这种做法下,时间复杂度被优化到了 ,而我们仅仅是换了一种遍历和统计的方式,就将 的复杂度拆分到 上去。
这种改变统计方式从而改变时间复杂度的trick也是我攻克这道题的卡点之一,其根源大概是SAM的DAG中,可以转移的边有条,不能的有 条,而我们的计数操作有 两种?而转移起来均为 , 所以要正确搭配计数转移和转移边?(感觉可以分析出来这里的不同,但是还说不上透彻,对这里产生差异的根源不够清楚)
于是我们可以得到正确时间复杂度计算的代码:

    matrix tr;
    tr[V - 1][V - 1] = 1;

    //tr[i][j] -> begin is i, can link j
    vector<array<ll,V>>cnt(m);
    //cnt[i][j], begin from node i,  can link char c 'cnt
    //attention the difference 
    vector<int>vs(m);
    auto dfs = [&](auto&&self, int p)->void
    {
        if(vs[p])return;
        vs[p] = true;
        cnt[p][V - 1] = 1;//initial 
        for(int i = 0;i < M;i++){
            int s = t[p].nxt[i];
            if(s){                
                self(self, s);
                //scnt[p] += scnt[s];
                cnt[p] += cnt[s]; //recalc
                //cnt[p][i] += scnt[s];
            }
            else{
                cnt[p][i]++;
            }
        }
    };
    dfs(dfs, 1);


    for(int i = 0;i < M;i++){
        int p = t[1].nxt[i];
        if(p){
            for(int j = 0;j < V;j++){
                tr[i][j] = cnt[p][j];
            }
        }
    }

得到了转移矩阵 ,我们做一遍矩阵快速幂就好。
这部分时间复杂度

统计答案的部分,我们直接计算从初始字符开始转移,以可以转移到结束状态的串的个数即可。
为什么是结束状态?
结束状态是一个"吸收态",代表着所有的合法转移的个数(合法转移都可以转移到结束态),而如果直接计算转移到各个字符结尾的串的加和,显然会重复计数。
初始状态字符 的设置是随意的,可以是不同的,只需要每个字符可以被转移到,仅用于计数,可以单开一维,甚至可以取结束状态。
于是可以全部设置成同一个字符 转移到全部字符,最后查询 到终止状态的方案数即可。
所以,将初始状态初始化为单位矩阵,最后计算每个字符转移到终止状态的和也是对的。
甚至可以随便设置,只需要保证每个字符被转移到一次即可,最后不重不漏地统计就行。
所以你这样初始化都居然都没问题:

    //matrix res = get_diag(1); // ok
    mt19937 g(chrono::steady_clock::now().time_since_epoch().count());
    const int pi = g() % V;
    vector<int>row(V),col(V);
    iota(row.begin(),row.end(),0);
    iota(col.begin(),col.end(),0);
    shuffle(row.begin(),row.end(),g);
    shuffle(col.begin(),col.end(),g);
    matrix res{};
    for(int i = 0;i < V;i++){
        //res[pi][i] = 1;//correct ,calc can use res[pi][V-1] as ans
        res[row[i]][col[i]] = 1; //correct
        //res[g() % V][col[i]] = 1; //correct
        //res[row[i]][g() % V] = 1; //wrong,can not assure col
    }

    while(k){
        if(k & 1)res = res * tr;
        tr = tr * tr;
        k >>= 1;
    }

    Z ans = 0;
    //ans = res[pi][V - 1];
    for(int i = 0;i < V;i++){
        ans += res[i][V - 1];
    }
    cout<<ans<<endl;
    return;

完整代码如下:

#include<bits/stdc++.h>
using i64 = long long;
using ll = long long;
using uint = unsigned int;
using ull = unsigned long long;

using namespace std;

constexpr int M = 52;//size 
constexpr int V = M + 1;
struct SAM
{
    struct Node{
        int len;
        int link;
        array<int,M>nxt;
        Node():len{},link{},nxt{}{}
    };
    vector<Node>t;
    vector<vector<int>>ot;
    vector<int>endpos_size;
    
    int last = 1;

    SAM(){
        init();
    }

    SAM(string& s){
        init();
        build(s);
    }

    void init()
    {
        t.assign(2,Node());
        t[0].len = -1;
        t[0].nxt.fill(1);
    }

    int newNode()
    {
        t.push_back(Node());
        return t.size() - 1;
    }

    static inline int num(int x)
    {
        if(islower(x))return x - 'a';
        else return x - 'A' + 26;
        //return x - 'a';
    }

    void extend(int x)
    {
        int cur = newNode();
        t[cur].len = t[last].len + 1;
        int p = last;
        while(p != 0 && t[p].nxt[x] == 0){
            t[p].nxt[x] = cur;
            p = t[p].link;
        }
        int q = t[p].nxt[x];
        if(p == 0){
            t[cur].link = 1; 
        }
        else if(t[q].len == t[p].len + 1){
            t[cur].link = q; 
        }
        else{
            int clone = newNode();
            t[clone].link = t[q].link;
            t[clone].nxt = t[q].nxt;
            t[clone].len = t[p].len + 1;
            t[cur].link = clone;
            t[q].link = clone;
            while(p != 0 && t[p].nxt[x] == q){
                t[p].nxt[x] = clone;
                p = t[p].link;
            }
        }
        last = cur;
        return;
    }

    void build(string& s)
    {
        for(auto &c : s){
            extend(num(c));
        }
        //get_out_linktree();
    }

    inline int nxt(int p, int x)
    {
        return t[p].nxt[x];
    }

    void get_out_linktree()
    {
        ot.resize(t.size());
        for(int i = 2;i < t.size();i++){
            ot[t[i].link].push_back(i);
        }
    }

    void calc_endpos_size(string &s)
    {
        endpos_size.resize(t.size());
        int p = 1;
        for(auto c:s){
            p = t[p].nxt[num(c)];
            endpos_size[p]++;
        }
        auto dfs = [&](auto&&self, int p)->void
        {
            for(auto s : ot[p]){
                self(self, s);
                endpos_size[p] += endpos_size[s];
            }
        };
        dfs(dfs,1);
        endpos_size[1] = 1;
    }
};

template<class T>
constexpr T power(T a, i64 b) {
    T res = 1;
    for (; b; b /= 2, a *= a) {
        if (b % 2) {
            res *= a;
        }
    }
    return res;
}
constexpr i64 mul(i64 a, i64 b, i64 p) {
    i64 res = a * b - i64(1.L * a * b / p) * p;
    res %= p;
    if (res < 0) {
        res += p;
    }
    return res;
}

template<int P>
struct MInt {
    int x;
    constexpr MInt() : x{} {}
    constexpr MInt(i64 x) : x{norm(x % getMod())} {}

    static int Mod;
    constexpr static int getMod() {
        if (P > 0) {
            return P;
        } else {
            return Mod;
        }
    }
    constexpr static void setMod(int Mod_) {
        Mod = Mod_;
    }
    constexpr int norm(int x) const {
        if (x < 0) {
            x += getMod();
        }
        if (x >= getMod()) {
            x -= getMod();
        }
        return x;
    }
    constexpr int val() const {
        return x;
    }
    explicit constexpr operator int() const {
        return x;
    }
    constexpr MInt operator-() const {
        MInt res;
        res.x = norm(getMod() - x);
        return res;
    }
    constexpr MInt inv() const {
        assert(x != 0);
        return power(*this, getMod() - 2);
    }
    constexpr MInt &operator*=(MInt rhs) & {
        x = 1LL * x * rhs.x % getMod();
        return *this;
    }
    constexpr MInt &operator+=(MInt rhs) & {
        x = norm(x + rhs.x);
        return *this;
    }
    constexpr MInt &operator-=(MInt rhs) & {
        x = norm(x - rhs.x);
        return *this;
    }
    constexpr MInt &operator/=(MInt rhs) & {
        return *this *= rhs.inv();
    }
    friend constexpr MInt operator*(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res *= rhs;
        return res;
    }
    friend constexpr MInt operator+(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res += rhs;
        return res;
    }
    friend constexpr MInt operator-(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res -= rhs;
        return res;
    }
    friend constexpr MInt operator/(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res /= rhs;
        return res;
    }
    friend constexpr istream &operator>>(istream &is, MInt &a) {
        i64 v;
        is >> v;
        a = MInt(v);
        return is;
    }
    friend constexpr ostream &operator<<(ostream &os, const MInt &a) {
        return os << a.val();
    }
    friend constexpr bool operator==(MInt lhs, MInt rhs) {
        return lhs.val() == rhs.val();
    }
    friend constexpr bool operator!=(MInt lhs, MInt rhs) {
        return lhs.val() != rhs.val();
    }
};
//template<>
//int MInt<0>::Mod = 998244353;
//int MInt<0>::Mod = 1000000007;
constexpr int P = 998244353;
//constexpr int P = 1000000007;
using Z = MInt<P>;


using matrix = array<array<Z,V>,V>;

matrix operator*(const matrix& lhs, const matrix& rhs)
{
    matrix res{};
    for(int i = 0;i < V;i++){
        for(int j = 0;j < V;j++){
            for(int k = 0;k < V;k++){
                res[i][j] += lhs[i][k] * rhs[k][j];
            }
        }
    }
    return res;
}

matrix get_diag(int x)
{
    matrix m;
    for(int i = 0;i < V;i++){
        m[i][i] = x;
    }
    return m;
}

template<typename T>
array<T, V> operator+(const array<T,V>&lhs, const array<T,V>&rhs)
{
    array<T,V>res{};
    for(int i = 0;i < V;i++){
        res[i] = lhs[i] + rhs[i];
    }
    return res;
}

template<typename T>
array<T, V>& operator+=(array<T,V>&lhs, const array<T,V>&rhs)
{
    for(int i = 0;i < V;i++){
        lhs[i] += rhs[i];
    }
    return lhs;
}

void print(matrix m)
{
    for(int i = 0;i < V;i++){
        for(int j = 0;j < V;j++){
            cerr<<m[i][j]<<" ";
        }
        cerr<<endl;
    }
    cerr<<endl;
}

// 3 0 3 
// 2 1 2 
// 0 0 1 

// 3 1 3 
// 2 1 2 
// 0 0 1 

void solve()
{
    int n,k;
    cin>>n>>k;
    string s;
    cin>>s;
    SAM sam(s);
    auto &t = sam.t;
    int m = t.size();
    matrix tr;
    tr[V - 1][V - 1] = 1;

    //tr[i][j] -> begin is i, can link j ?

    //vector<ll>scnt(m, 1);
    vector<array<ll,V>>cnt(m);
    vector<int>vs(m);
    auto dfs = [&](auto&&self, int p)->void
    {
        if(vs[p])return;
        vs[p] = true;
        cnt[p][V - 1] = 1;//initial 
        for(int i = 0;i < M;i++){
            int s = t[p].nxt[i];
            if(s){                
                self(self, s);
                cnt[p] += cnt[s]; //recalc
            }
            else{
                cnt[p][i]++;// ? 
            }
        }
    };
    dfs(dfs, 1);


    for(int i = 0;i < M;i++){
        int p = t[1].nxt[i];
        if(p){
            for(int j = 0;j < V;j++){
                tr[i][j] = cnt[p][j];
            }
        }
    }

    //print(tr);

    // vector<array<int,V>>cnt(m);
    // vector<int>indeg(m);
    // for(int i = 1;i < m;i++){
    //     for(int j = 0;j < M;j++){
    //         if(t[i].nxt[j] != 0){
    //             indeg[t[i].nxt[j]]++;
    //         }
    //     }
    // }
    // for(int i = 0;i < M;i++){
    //     if(t[1].nxt[i] != 0)
    //         cnt[t[1].nxt[i]][i]++;
    // }
    // queue<int>q;
    // q.push(1);
    // while(!q.empty()){
    //     int p = q.front();
    //     q.pop();
    //     for(int i = 0;i < M;i++){
            
    //         int s = t[p].nxt[i];//first cases
    //         if(s){
    //             cnt[s] += cnt[p];
    //             indeg[s]--;
    //             if(indeg[s] == 0)q.push(s);
    //         }
    //         else{
    //             for(int j = 0;j < M;j++){//O(nV^2)
    //                 tr[j][i] += cnt[p][j];
    //             }
    //         }
    //     }
    //     for(int j = 0;j < M;j++){
    //         tr[j][V - 1] += cnt[p][j];
    //     }
    // }
    // print(tr);


    //matrix res = get_diag(1);
    mt19937 g(chrono::steady_clock::now().time_since_epoch().count());
    const int pi = g() % V;
    matrix res{};
    for(int i = 0;i < V;i++){
        res[pi][i] = 1;
    }

    // print(tr);
    // print(res);

    while(k){
        if(k & 1)res = res * tr;
        tr = tr * tr;
        k >>= 1;
    }

    //print(res);

    Z ans = 0;
    ans = res[pi][V - 1];
    // for(int i = 0;i < V;i++){
    //     ans += res[i][V - 1];
    //     //ans += res[V - 1][i];
    // }
    cout<<ans<<endl;
    return;
}

int main()
{
    std::ios::sync_with_stdio(0);
    std::cin.tie(0);
    int tt = 1;
    //cin >> tt;
    while(tt--){
        solve();
    }
    return 0;
}