牛客33194G多校 - Magic Spells
题意
- 给出 K 个字符串,求所有字符串的公共回文子串的个数。
- K≤5,∑n≤3×105
思路
错误思路
- 先用 Manacher 处理出第一个字符串的回文串,从左到右分别计算以每个位置为中心的每个回文串的哈希,放入 set 中。
- 错误原因:复杂度太大。
正确思路
- 先用 Manacher 处理出第一个字符串的回文串,从左到右分别计算以每个位置为中心的每个回文串的哈希,
- 对于每个位置,定义两个指针 l 和 r,初始 l 和 r 位于以当前位置为中心的最长回文串的左右端点。
- l 和 r 不断向内聚拢,每次计算 hash[l,r] 并放入 set 中。
- 如果发现 hash[l,r] 已经存在于 set 中,就停止,看下一个位置。
- 我们能观察到,以上做法复杂度为 O(nlogn)
- 将所有字符串都进行类似的操作,存到不同的 set 中,统计即可。
代码
#include <cstdio>
#include <iostream>
#include <cstring>
#include <map>
#include <set>
#include <algorithm>
#define int long long
const int N = 3e5+10;
const int MOD = 1e9+7;
const int BASE1 = 131;
const int BASE2 = 1331;
using namespace std;
struct MANACHER
{
char str1[N];//原串
char str2[N*3];//转换后的串(转换后的串长度为 2n+2 ),形态:abc -> ^$a$b$c$&
int p[N*3];// pi - 1:以 i/2 为中心的最长回文串长度,除2可以是小数,如“3333”的p5(即原串第2.5位)就 = 5,再-1,就是真正的长度
int n;
int len;
pair<int,int> pi[N];
int change()
{
n=strlen(str1+1);
str2[1]='$';
int tot=1;
for (int i=1;i<=n;i++)
{
str2[++tot]=str1[i];
str2[++tot]='$';
}
str2[++tot]='&';
return tot;
}
int Manacher()
{
len=change();
int mid=1,mx=1,ans=-1;//ans = 最长的回文串长度
for (int i=1;i<=len;i++)
{
if (i<mx)
p[i]=min(mx-i,p[mid*2-i]);
else
p[i]=1;
while (str2[i-p[i]]==str2[i+p[i]])
p[i]++;
if (mx<i+p[i])
{
mid=i;
mx=i+p[i];
}
ans=max(ans,p[i]-1);
}
for (int i=1;i<=len;i++)
pi[i] = {p[i],i};
return ans;
}
};
int bin1[N], bin2[N];
void Init()
{
bin1[0] = bin2[0] = 1;
for (int i=1; i<N; i++)
{
bin1[i] = bin1[i-1] * BASE1 % MOD;
bin2[i] = bin2[i-1] * BASE2 % MOD;
}
}
struct HASH
{
char str[N];//1-idx
int hash1[N],hash2[N];
int n;
void Init()
{
n=strlen(str+1);
for (int i=1; i<=n; i++)//开long long
{
hash1[i] = hash1[i-1] * BASE1 % MOD + str[i]-'a'+1, hash1[i] %= MOD;
hash2[i] = hash2[i-1] * BASE2 % MOD + str[i]-'a'+1, hash2[i] %= MOD;
}
}
int Hash1(int l,int r)
{
int len = r-l+1;
return (hash1[r] - hash1[l-1] * bin1[len] % MOD + MOD) % MOD ;
}
int Hash2(int l,int r)
{
int len = r-l+1;
return (hash2[r] - hash2[l-1] * bin2[len] % MOD + MOD) % MOD ;
}
};
set< pair<int,int> > st[7];
HASH ha[6];
MANACHER mcr[6];
int K;
void Sol()
{
for (int i=1; i<=K; i++)
{
ha[i].Init();
mcr[i].Manacher();
}
for (int i=1; i<=mcr[1].len; i++)
{
if(mcr[1].pi[i].first-1 <= 0)
continue;
int l = (mcr[1].pi[i].second/2) - (mcr[1].pi[i].first-1)/2 + (mcr[1].pi[i].second%2);
int r = (mcr[1].pi[i].second/2) + (mcr[1].pi[i].first-1)/2;
while (l<=r)
{
int h1 = ha[1].Hash1(l, r), h2 = ha[1].Hash2(l, r);
if(st[1].find({h1, h2})!=st[1].end())
break;
st[1].insert({h1, h2});
l++, r--;
}
}
if(K==1)
{
printf("%lu\n",st[1].size());
return;
}
for (int k=2; k<=K; k++)
{
for (int i=1; i<=mcr[k].len; i++)
{
if(mcr[k].pi[i].first-1 <= 0)
continue;
int l = (mcr[k].pi[i].second/2) - (mcr[k].pi[i].first-1)/2 + (mcr[k].pi[i].second%2);
int r = (mcr[k].pi[i].second/2) + (mcr[k].pi[i].first-1)/2;
while (l<=r)
{
int h1 = ha[k].Hash1(l, r), h2 = ha[k].Hash2(l, r);
if(st[k].find({h1, h2})!=st[k].end())
break;
st[k].insert({h1,h2});
l++, r--;
}
}
}
int ans=0;
for (auto item : st[1])
{
for (int i=2; i<=K; i++)
{
if(st[i].find(item)==st[i].end())
break;
else if(i==K)ans++;
}
}
printf("%lld\n",ans);
}
signed main()
{
Init();
scanf("%lld",&K);
for (int i=1; i<=K; i++)
{
scanf("%s",ha[i].str+1);
strcpy(mcr[i].str1+1, ha[i].str+1);
}
Sol();
return 0;
}