2025 牛客多校 Round 5 B Extra Training
大概题意
给定一个字符串 , 每次可以选择
的一个子串(可以是空串)拼接, 问
次拼接可以形成多少种不同的串。
范围:
分析
显然不能够直接拼接 次,得考虑使用矩阵快速幂加速递推。
我们需要对拼接出来的字符串去重,所以我们可以选择一个字符串最早被拼出来的那一次进行计算。
如果我们要计算一个字符串被最早拼出来的一次,我们必须要对选择来拼接的子串做出限制以达到我们的目标。
一个好理解的结论 : 如果一个子串 加上一个字符
后仍然属于原串的子串,那么显然这不会是一个 "最快的" 方案。
那么我们如果做出如下限制: 对于我们在串后拼接的子串,只能选择符合"最快"方案的,即如果我们当前拼接出来的串中,最后一个子串为 , 且定义
的接受字符集
, 那么我们只能选择以
中字符开头的子串进行拼接。
再考虑设置一个结束状态 ,任何字符开头的子串可以后接 结束状态
。
这个结束状态可以理解为空串,开始拼接空串即代表构造结束。
于是我们拼接一个子串 只受到开头字符的限制,并在拼接后变成
的限制。
在后缀自动机上,我们可以在DAG上得到每个endpos集合的第一个字符的集合,也可以得到它们后接哪些字符可以转移,哪些不能。
于是我们设计如下的递推状态转移:
表示当前拼接出来的串中,以
字符开头, 且可以后接
字符的串的个数。其中也包含结束状态字符。
表示
中以
字符开头,且可以后接
字符的子串的个数。
滚动递推,有:
完全就是矩阵乘法的形式!
于是我们可以将这个转移写成矩阵递推的形式,使用矩阵快速幂优化。
初始的dp状态是一个初始字符 ,可以后接任意字符开头的串。所以是
下面我们想办法利用后缀自动机求出转移矩阵
特殊的,我们在结束状态只能拼接一个空串,此时结束状态仍然保持,所以有
一般的,我们要求出所有以 字符开头,且可以后接
字符的子串的个数。
如果直接先序遍历SAM的DAG,在转移的时候记录每个节点以每个字符开头的子串数,计数是 的,因为后缀自动机的转移边只有
条。
然而,后续统计是麻烦的。对于每个不能转移的边,我们得花费 的时间将贡献加到
上,这是
的。而这个复杂度的代码将会超时。(虽然哥哥的代码卡过去了)
可能的代码如下:
vector<array<int,V>>cnt(m);
//cnt[i][j] ->node is i, first char is j cnt
vector<int>indeg(m);//tmp for topo sort
for(int i = 1;i < m;i++){
for(int j = 0;j < M;j++){
if(t[i].nxt[j] != 0){
indeg[t[i].nxt[j]]++;
}
}
}
// init first char from root, when tranverse, first is keep
for(int i = 0;i < M;i++){
if(t[1].nxt[i] != 0)
cnt[t[1].nxt[i]][i]++;
}
queue<int>q;
q.push(1);
while(!q.empty()){//topo order tranverse
int p = q.front();
q.pop();
for(int i = 0;i < M;i++){
int s = t[p].nxt[i];//first cases
if(s){
cnt[s] += cnt[p];
indeg[s]--;
if(indeg[s] == 0)q.push(s);
}
else{
for(int j = 0;j < M;j++){//O(nV^2)
tr[j][i] += cnt[p][j];
}
}
}
for(int j = 0;j < M;j++){
tr[j][V - 1] += cnt[p][j];
}
}
print(tr);
这是因为我们预处理以每个字符开头的串后,将其均摊到每个节点上去计算了。也就是我们是通过DAG转移开头字符信息,在节点上计算可否后接字符信息。
但是如果我们在DAG上转移可否后接字符信息,在节点上计算开头字符信息,情况是否会变得不一样呢?
使用这种做法,我们使用后序遍历来遍历DAG,计算是否可以转移时,如果不能,我们加到该节点的转移计数数组上,而这大概有
次。
然后,我们通过转移边,将其计数信息
转移到父节点上(当然可能有多个,这代表前接一个字符),而转移边有
条,这部分的复杂度也是
的。
最后,我们只需要通过根节点访问以每个字符开头的"根节点",遍历可以后接的字符计数数组,加入到 中,这部分时间复杂度是
。
在这种做法下,时间复杂度被优化到了
,而我们仅仅是换了一种遍历和统计的方式,就将
的复杂度拆分到
和
上去。
这种改变统计方式从而改变时间复杂度的trick也是我攻克这道题的卡点之一,其根源大概是SAM的DAG中,可以转移的边有条,不能的有
条,而我们的计数操作有
和
两种?而转移起来均为
, 所以要正确搭配计数转移和转移边?(感觉可以分析出来这里的不同,但是还说不上透彻,对这里产生差异的根源不够清楚)
于是我们可以得到正确时间复杂度计算的代码:
matrix tr;
tr[V - 1][V - 1] = 1;
//tr[i][j] -> begin is i, can link j
vector<array<ll,V>>cnt(m);
//cnt[i][j], begin from node i, can link char c 'cnt
//attention the difference
vector<int>vs(m);
auto dfs = [&](auto&&self, int p)->void
{
if(vs[p])return;
vs[p] = true;
cnt[p][V - 1] = 1;//initial
for(int i = 0;i < M;i++){
int s = t[p].nxt[i];
if(s){
self(self, s);
//scnt[p] += scnt[s];
cnt[p] += cnt[s]; //recalc
//cnt[p][i] += scnt[s];
}
else{
cnt[p][i]++;
}
}
};
dfs(dfs, 1);
for(int i = 0;i < M;i++){
int p = t[1].nxt[i];
if(p){
for(int j = 0;j < V;j++){
tr[i][j] = cnt[p][j];
}
}
}
得到了转移矩阵 ,我们做一遍矩阵快速幂就好。
这部分时间复杂度
统计答案的部分,我们直接计算从初始字符开始转移,以可以转移到结束状态的串的个数即可。
为什么是结束状态?
结束状态是一个"吸收态",代表着所有的合法转移的个数(合法转移都可以转移到结束态),而如果直接计算转移到各个字符结尾的串的加和,显然会重复计数。
初始状态字符 的设置是随意的,可以是不同的,只需要每个字符可以被转移到,仅用于计数,可以单开一维,甚至可以取结束状态。
于是可以全部设置成同一个字符 转移到全部字符,最后查询
到终止状态的方案数即可。
所以,将初始状态初始化为单位矩阵,最后计算每个字符转移到终止状态的和也是对的。
甚至可以随便设置,只需要保证每个字符被转移到一次即可,最后不重不漏地统计就行。
所以你这样初始化都居然都没问题:
//matrix res = get_diag(1); // ok
mt19937 g(chrono::steady_clock::now().time_since_epoch().count());
const int pi = g() % V;
vector<int>row(V),col(V);
iota(row.begin(),row.end(),0);
iota(col.begin(),col.end(),0);
shuffle(row.begin(),row.end(),g);
shuffle(col.begin(),col.end(),g);
matrix res{};
for(int i = 0;i < V;i++){
//res[pi][i] = 1;//correct ,calc can use res[pi][V-1] as ans
res[row[i]][col[i]] = 1; //correct
//res[g() % V][col[i]] = 1; //correct
//res[row[i]][g() % V] = 1; //wrong,can not assure col
}
while(k){
if(k & 1)res = res * tr;
tr = tr * tr;
k >>= 1;
}
Z ans = 0;
//ans = res[pi][V - 1];
for(int i = 0;i < V;i++){
ans += res[i][V - 1];
}
cout<<ans<<endl;
return;
完整代码如下:
#include<bits/stdc++.h>
using i64 = long long;
using ll = long long;
using uint = unsigned int;
using ull = unsigned long long;
using namespace std;
constexpr int M = 52;//size
constexpr int V = M + 1;
struct SAM
{
struct Node{
int len;
int link;
array<int,M>nxt;
Node():len{},link{},nxt{}{}
};
vector<Node>t;
vector<vector<int>>ot;
vector<int>endpos_size;
int last = 1;
SAM(){
init();
}
SAM(string& s){
init();
build(s);
}
void init()
{
t.assign(2,Node());
t[0].len = -1;
t[0].nxt.fill(1);
}
int newNode()
{
t.push_back(Node());
return t.size() - 1;
}
static inline int num(int x)
{
if(islower(x))return x - 'a';
else return x - 'A' + 26;
//return x - 'a';
}
void extend(int x)
{
int cur = newNode();
t[cur].len = t[last].len + 1;
int p = last;
while(p != 0 && t[p].nxt[x] == 0){
t[p].nxt[x] = cur;
p = t[p].link;
}
int q = t[p].nxt[x];
if(p == 0){
t[cur].link = 1;
}
else if(t[q].len == t[p].len + 1){
t[cur].link = q;
}
else{
int clone = newNode();
t[clone].link = t[q].link;
t[clone].nxt = t[q].nxt;
t[clone].len = t[p].len + 1;
t[cur].link = clone;
t[q].link = clone;
while(p != 0 && t[p].nxt[x] == q){
t[p].nxt[x] = clone;
p = t[p].link;
}
}
last = cur;
return;
}
void build(string& s)
{
for(auto &c : s){
extend(num(c));
}
//get_out_linktree();
}
inline int nxt(int p, int x)
{
return t[p].nxt[x];
}
void get_out_linktree()
{
ot.resize(t.size());
for(int i = 2;i < t.size();i++){
ot[t[i].link].push_back(i);
}
}
void calc_endpos_size(string &s)
{
endpos_size.resize(t.size());
int p = 1;
for(auto c:s){
p = t[p].nxt[num(c)];
endpos_size[p]++;
}
auto dfs = [&](auto&&self, int p)->void
{
for(auto s : ot[p]){
self(self, s);
endpos_size[p] += endpos_size[s];
}
};
dfs(dfs,1);
endpos_size[1] = 1;
}
};
template<class T>
constexpr T power(T a, i64 b) {
T res = 1;
for (; b; b /= 2, a *= a) {
if (b % 2) {
res *= a;
}
}
return res;
}
constexpr i64 mul(i64 a, i64 b, i64 p) {
i64 res = a * b - i64(1.L * a * b / p) * p;
res %= p;
if (res < 0) {
res += p;
}
return res;
}
template<int P>
struct MInt {
int x;
constexpr MInt() : x{} {}
constexpr MInt(i64 x) : x{norm(x % getMod())} {}
static int Mod;
constexpr static int getMod() {
if (P > 0) {
return P;
} else {
return Mod;
}
}
constexpr static void setMod(int Mod_) {
Mod = Mod_;
}
constexpr int norm(int x) const {
if (x < 0) {
x += getMod();
}
if (x >= getMod()) {
x -= getMod();
}
return x;
}
constexpr int val() const {
return x;
}
explicit constexpr operator int() const {
return x;
}
constexpr MInt operator-() const {
MInt res;
res.x = norm(getMod() - x);
return res;
}
constexpr MInt inv() const {
assert(x != 0);
return power(*this, getMod() - 2);
}
constexpr MInt &operator*=(MInt rhs) & {
x = 1LL * x * rhs.x % getMod();
return *this;
}
constexpr MInt &operator+=(MInt rhs) & {
x = norm(x + rhs.x);
return *this;
}
constexpr MInt &operator-=(MInt rhs) & {
x = norm(x - rhs.x);
return *this;
}
constexpr MInt &operator/=(MInt rhs) & {
return *this *= rhs.inv();
}
friend constexpr MInt operator*(MInt lhs, MInt rhs) {
MInt res = lhs;
res *= rhs;
return res;
}
friend constexpr MInt operator+(MInt lhs, MInt rhs) {
MInt res = lhs;
res += rhs;
return res;
}
friend constexpr MInt operator-(MInt lhs, MInt rhs) {
MInt res = lhs;
res -= rhs;
return res;
}
friend constexpr MInt operator/(MInt lhs, MInt rhs) {
MInt res = lhs;
res /= rhs;
return res;
}
friend constexpr istream &operator>>(istream &is, MInt &a) {
i64 v;
is >> v;
a = MInt(v);
return is;
}
friend constexpr ostream &operator<<(ostream &os, const MInt &a) {
return os << a.val();
}
friend constexpr bool operator==(MInt lhs, MInt rhs) {
return lhs.val() == rhs.val();
}
friend constexpr bool operator!=(MInt lhs, MInt rhs) {
return lhs.val() != rhs.val();
}
};
//template<>
//int MInt<0>::Mod = 998244353;
//int MInt<0>::Mod = 1000000007;
constexpr int P = 998244353;
//constexpr int P = 1000000007;
using Z = MInt<P>;
using matrix = array<array<Z,V>,V>;
matrix operator*(const matrix& lhs, const matrix& rhs)
{
matrix res{};
for(int i = 0;i < V;i++){
for(int j = 0;j < V;j++){
for(int k = 0;k < V;k++){
res[i][j] += lhs[i][k] * rhs[k][j];
}
}
}
return res;
}
matrix get_diag(int x)
{
matrix m;
for(int i = 0;i < V;i++){
m[i][i] = x;
}
return m;
}
template<typename T>
array<T, V> operator+(const array<T,V>&lhs, const array<T,V>&rhs)
{
array<T,V>res{};
for(int i = 0;i < V;i++){
res[i] = lhs[i] + rhs[i];
}
return res;
}
template<typename T>
array<T, V>& operator+=(array<T,V>&lhs, const array<T,V>&rhs)
{
for(int i = 0;i < V;i++){
lhs[i] += rhs[i];
}
return lhs;
}
void print(matrix m)
{
for(int i = 0;i < V;i++){
for(int j = 0;j < V;j++){
cerr<<m[i][j]<<" ";
}
cerr<<endl;
}
cerr<<endl;
}
// 3 0 3
// 2 1 2
// 0 0 1
// 3 1 3
// 2 1 2
// 0 0 1
void solve()
{
int n,k;
cin>>n>>k;
string s;
cin>>s;
SAM sam(s);
auto &t = sam.t;
int m = t.size();
matrix tr;
tr[V - 1][V - 1] = 1;
//tr[i][j] -> begin is i, can link j ?
//vector<ll>scnt(m, 1);
vector<array<ll,V>>cnt(m);
vector<int>vs(m);
auto dfs = [&](auto&&self, int p)->void
{
if(vs[p])return;
vs[p] = true;
cnt[p][V - 1] = 1;//initial
for(int i = 0;i < M;i++){
int s = t[p].nxt[i];
if(s){
self(self, s);
cnt[p] += cnt[s]; //recalc
}
else{
cnt[p][i]++;// ?
}
}
};
dfs(dfs, 1);
for(int i = 0;i < M;i++){
int p = t[1].nxt[i];
if(p){
for(int j = 0;j < V;j++){
tr[i][j] = cnt[p][j];
}
}
}
//print(tr);
// vector<array<int,V>>cnt(m);
// vector<int>indeg(m);
// for(int i = 1;i < m;i++){
// for(int j = 0;j < M;j++){
// if(t[i].nxt[j] != 0){
// indeg[t[i].nxt[j]]++;
// }
// }
// }
// for(int i = 0;i < M;i++){
// if(t[1].nxt[i] != 0)
// cnt[t[1].nxt[i]][i]++;
// }
// queue<int>q;
// q.push(1);
// while(!q.empty()){
// int p = q.front();
// q.pop();
// for(int i = 0;i < M;i++){
// int s = t[p].nxt[i];//first cases
// if(s){
// cnt[s] += cnt[p];
// indeg[s]--;
// if(indeg[s] == 0)q.push(s);
// }
// else{
// for(int j = 0;j < M;j++){//O(nV^2)
// tr[j][i] += cnt[p][j];
// }
// }
// }
// for(int j = 0;j < M;j++){
// tr[j][V - 1] += cnt[p][j];
// }
// }
// print(tr);
//matrix res = get_diag(1);
mt19937 g(chrono::steady_clock::now().time_since_epoch().count());
const int pi = g() % V;
matrix res{};
for(int i = 0;i < V;i++){
res[pi][i] = 1;
}
// print(tr);
// print(res);
while(k){
if(k & 1)res = res * tr;
tr = tr * tr;
k >>= 1;
}
//print(res);
Z ans = 0;
ans = res[pi][V - 1];
// for(int i = 0;i < V;i++){
// ans += res[i][V - 1];
// //ans += res[V - 1][i];
// }
cout<<ans<<endl;
return;
}
int main()
{
std::ios::sync_with_stdio(0);
std::cin.tie(0);
int tt = 1;
//cin >> tt;
while(tt--){
solve();
}
return 0;
}