题目考点:KMP
题目大意:给定n个字符串,对于每一个字符串,计算出其在n个字符串中出现的次数的乘积
普通(超时)思路:O(n^2)进行KMP
for(int i = 0; i < n; i++)
{
int ans = 1, cnt = 0;
for(int j = 0; j < n; j++)
{
cnt = kmp(r[i], r[j]);
ans = ans * cnt % mod;
}
cout << ans << '\n';
}
毫无疑问T了
我们仔细思考一下,当前字符串S在别的字符串出现,当且仅当满足S比别的字符串都要短时才可行的对吧?也就是说,我们只需要考虑长度最短的字符串去和别的字符串做KMP就可以了
明白上一点后,我们显然可以得到一个特殊情况:如果最短长度的字符串出现多个,那么对于所有的字符串都输出0就可以了。
代码:
#include<iostream>
#include<cstring>
#include<algorithm>
#include<vector>
#define js std::ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
using namespace std;
typedef long long LL;
const int N = 2000010, mod = 998244353;
int n, ne[N];
string s;
vector<string> r;
string minn; int len = 0x3f3f3f3f; // 记录长度最短的字符串
void getne(string s1)
{
ne[0] = -1; // 下标从0开始,ne[0]为-1
int len1 = s1.size();
for(int i = 1, j = -1; i < len1; i++)
{
while(j > -1 && s1[i] != s1[j+1]) j = ne[j];
if(s1[i] == s1[j+1]) j++;
ne[i] = j;
}
}
LL match(string s1, string s2)
{
LL cnt = 0;
int len1 = s1.size();
int len2 = s2.size();
for(int i = 0, j = -1; i < len2; i++)
{
while(j > -1 && s2[i] != s1[j+1]) j = ne[j];
if(s2[i] == s1[j+1]) j++;
if(j == len1-1)
{
cnt ++;
j = ne[len1-1];
}
}
return cnt;
}
int main()
{
js;
cin >> n;
for(int i = 0; i < n; i++)
{
cin >> s; r.push_back(s);
if(len > s.size())
minn = s, len = s.size(); // 维护长度最短的字符串
}
for(int i = 0; i < n; i++) // 如果有多个长度最短但不相同的字符串的情况
{
if(r[i].size() == len && r[i] != minn)
{
for(int i = 0; i < n; i++) cout << "0\n";
return 0;
}
}
getne(minn); // 求出minn的next数组
LL ans = 1;
for(int i = 0; i < n; i++)
{
LL cnt = match(minn, r[i]);
ans = ans * cnt % mod;
}
for(int i = 0; i < n; i++)
{
if(r[i].size() != len) cout << "0\n";
else cout << ans << '\n'; // 只有长度最短的字符串不是0
}
return 0;
}