牛客33194G多校 - Magic Spells

题意

  • 给出 KK 个字符串,求所有字符串的公共回文子串的个数。
  • K5K \leq 5n3×105\sum n \leq 3\times 10^5

思路

错误思路

  • 先用 Manacher 处理出第一个字符串的回文串,从左到右分别计算以每个位置为中心的每个回文串的哈希,放入 set 中。
  • 错误原因:复杂度太大。

正确思路

  • 先用 Manacher 处理出第一个字符串的回文串,从左到右分别计算以每个位置为中心的每个回文串的哈希,
  • 对于每个位置,定义两个指针 llrr,初始 llrr 位于以当前位置为中心的最长回文串的左右端点。
  • llrr 不断向内聚拢,每次计算 hash[l,r]\text{hash}[l,r] 并放入 set 中。
  • 如果发现 hash[l,r]\text{hash}[l,r] 已经存在于 set 中,就停止,看下一个位置。
  • 我们能观察到,以上做法复杂度为 O(nlogn)O(n\text{log}n)
  • 将所有字符串都进行类似的操作,存到不同的 set 中,统计即可。

代码

#include <cstdio>
#include <iostream>
#include <cstring>
#include <map>
#include <set>
#include <algorithm>
#define int long long
const int N		= 3e5+10;
const int MOD	= 1e9+7;
const int BASE1 = 131;
const int BASE2 = 1331;
using namespace std;

struct MANACHER
{
    char str1[N];//原串
    char str2[N*3];//转换后的串(转换后的串长度为 2n+2 ),形态:abc -> ^$a$b$c$&
    int p[N*3];// pi - 1:以 i/2 为中心的最长回文串长度,除2可以是小数,如“3333”的p5(即原串第2.5位)就 = 5,再-1,就是真正的长度
    int n;
    int len;
	
	pair<int,int> pi[N];
    
    int change()
    {
        n=strlen(str1+1);
        
        str2[1]='$';
        int tot=1;
        for (int i=1;i<=n;i++)
        {
            str2[++tot]=str1[i];
            str2[++tot]='$';
        }
        str2[++tot]='&';
        return tot;
    }
    
    int Manacher()
    {
        len=change();
        int mid=1,mx=1,ans=-1;//ans = 最长的回文串长度
        for (int i=1;i<=len;i++)
        {
            if (i<mx)
                p[i]=min(mx-i,p[mid*2-i]);
            else
                p[i]=1;
            while (str2[i-p[i]]==str2[i+p[i]])
                p[i]++;
            if (mx<i+p[i])
            {
                mid=i;
                mx=i+p[i];
            }
            ans=max(ans,p[i]-1);
        }
        
		for (int i=1;i<=len;i++)
			pi[i] = {p[i],i};
        return ans;
    }
};


int bin1[N], bin2[N];

void Init()
{
	bin1[0] = bin2[0] = 1;
	for (int i=1; i<N; i++)
	{
		bin1[i] = bin1[i-1] * BASE1 % MOD;
		bin2[i] = bin2[i-1] * BASE2 % MOD;
	}
}

struct HASH
{
	char str[N];//1-idx
	int hash1[N],hash2[N];
	int n;
	void Init()
	{
		n=strlen(str+1);
		for (int i=1; i<=n; i++)//开long long
		{
			hash1[i] = hash1[i-1] * BASE1 % MOD + str[i]-'a'+1, hash1[i] %= MOD;
			hash2[i] = hash2[i-1] * BASE2 % MOD + str[i]-'a'+1, hash2[i] %= MOD;
		}
	}
	int Hash1(int l,int r)
	{
		int len = r-l+1;
		return (hash1[r] - hash1[l-1] * bin1[len] % MOD + MOD) % MOD ;
	}
	
	int Hash2(int l,int r)
	{
		int len = r-l+1;
		return (hash2[r] - hash2[l-1] * bin2[len] % MOD + MOD) % MOD ;
	}
};

set< pair<int,int> > st[7];

HASH ha[6];
MANACHER mcr[6];

int K;

void Sol()
{
	for (int i=1; i<=K; i++)
	{
		ha[i].Init();
		mcr[i].Manacher();
	}
	
	
	for (int i=1; i<=mcr[1].len; i++)
	{
		if(mcr[1].pi[i].first-1 <= 0)
			continue;
		
		int l = (mcr[1].pi[i].second/2) - (mcr[1].pi[i].first-1)/2 + (mcr[1].pi[i].second%2);
		int r = (mcr[1].pi[i].second/2) + (mcr[1].pi[i].first-1)/2;
		
		while (l<=r)
		{
			int h1 = ha[1].Hash1(l, r), h2 = ha[1].Hash2(l, r);
			if(st[1].find({h1, h2})!=st[1].end())
				break;
			
			st[1].insert({h1, h2});
			l++, r--;
		}
	}
	
	if(K==1)
	{
		printf("%lu\n",st[1].size());
		return;
	}
	
	for (int k=2; k<=K; k++)
	{
		for (int i=1; i<=mcr[k].len; i++)
		{
			if(mcr[k].pi[i].first-1 <= 0)
				continue;
			
			int l = (mcr[k].pi[i].second/2) - (mcr[k].pi[i].first-1)/2 + (mcr[k].pi[i].second%2);
			int r = (mcr[k].pi[i].second/2) + (mcr[k].pi[i].first-1)/2;
			
			while (l<=r)
			{
				int h1 = ha[k].Hash1(l, r), h2 = ha[k].Hash2(l, r);
				if(st[k].find({h1, h2})!=st[k].end())
					break;
				
				st[k].insert({h1,h2});
				l++, r--;
			}
		}
	}
	
	
	int ans=0;
	for (auto item : st[1])
	{
		for (int i=2; i<=K; i++)
		{
			if(st[i].find(item)==st[i].end())
				break;
			else if(i==K)ans++;
		}
	}
	
	printf("%lld\n",ans);
}


signed main()
{
	Init();
	scanf("%lld",&K);
	for (int i=1; i<=K; i++)
	{
		scanf("%s",ha[i].str+1);
		strcpy(mcr[i].str1+1, ha[i].str+1);
	}
	Sol();
	return 0;
}