All with Pairs
题目大意
给n个字符串,f(s1,s2) 代表 s1的前缀和s2的后缀相同的最长长度。
求上面的那个式子。
题解
如果算的不是最长长度而是所有的话,就直接先hash一下后缀,在每个前缀里找有多少个后缀与这个前缀相等就可以。
但是这个题要求最大值,并不是所有的,那么怎么去除除了最大值的呢?
比如 两个字符串 aba 与 aba
a 与 a 算一次 ,aba 与 aba 又算一次。
所以得删除掉a算的那次,怎么去掉那次呢?
先看一下什么情况要删除吧
也就是 算前缀的时候从前往后遍历,如果遇到一个 就把他前面的删掉,也就是 一个前缀k 要把k 的最长的前缀跟k最长后缀相等的就把那个前缀的 减去这个点。
例如这个,如果以j为结尾的前缀是下面字符串的后缀,那么 根据kmp中next的定义以next[j] - 1结尾的也肯定是它的后缀,所以应该减去。
不能直接清零,应该在next[j] - 1的地方减去j 匹配的数量,因为j匹配的next[j] - 1一定可以匹配,但是next[j] - 1匹配的地方j不一定可以匹配。
记得取模!!!!
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>
#include <stack>
#include <queue>
#include <map>
#include <cmath>
#include <set>
#include <cstring>
#include <string>
#include <bitset>
#include <stdlib.h>
#include <time.h>
using namespace std;
typedef long long ll;
typedef pair<int,ll> pii;
typedef unsigned long long ull;
typedef set<int>::iterator sit;
#define st first
#define sd second
#define mkp make_pair
#define pb push_back
void wenjian(){
freopen("concatenation.in","r",stdin);freopen("concatenation.out","w",stdout);}
void tempwj(){
freopen("hash.in","r",stdin);freopen("hash.out","w",stdout);}
ll gcd(ll a,ll b){
return b == 0 ? a : gcd(b,a % b);}
ll qpow(ll a,ll b,ll mod){
a %= mod;ll ans = 1;while(b){
if(b & 1)ans = ans * a % mod;a = a * a % mod;b >>= 1;}return ans;}
struct cmp{
bool operator()(const pii & a, const pii & b){
return a.second > b.second;}};
int lb(int x){
return x & -x;}
const int INF = 0x3f3f3f3f;
const ll inf = 0x3f3f3f3f3f3f3f3f;
const int maxn = 1e5 + 4;
const int M = 4e4+2;
const ll mod = 998244353;
string a[maxn];
map<ull,ll> mm;
ull seed = 131;
void chuli(string a)
{
int m = a.length();
ull temp = 1;
ull t = 0;
for (int i = m - 1; i >= 0; i -- )
{
t += (a[i] - 'a' + 1) * temp;
temp *= seed;
mm[t] ++;
}
}
int nex[maxn];
int ans[maxn];
int main()
{
int n;
scanf("%d",&n);
for (int i = 1; i <= n; i ++ )
{
cin>>a[i];
chuli(a[i]);
}
// cout<<"111"<<endl;
ll res = 0;
for (int i = 1; i <= n; i ++ )
{
int m = a[i].length();
nex[0] = 0;
int k = 0;
for (int j = 1; j < m; j ++ )
{
while(k && a[i][j] != a[i][k])
{
// cout<<k<<endl;
k = nex[k - 1];
}
if(a[i][j] == a[i][k])
k ++ ;
nex[j] = k;
}
// for (int j = 0; j < m; j ++ )
// printf("%d ",nex[j]);
// printf("\n");
ull t = 0;
for (int j = 0; j < m; j ++ )
{
t = t * seed + (a[i][j] - 'a' + 1);
ans[j] = mm[t];
if(nex[j])
ans[nex[j] - 1] -= ans[j];
}
for (int j = 0; j < m; j ++ )
{
// cout<<ans[j]<<" ";
res += 1ll * (j + 1) * (j + 1) % mod * ans[j]% mod;
res %= mod;
}
// cout<<endl;
}
cout<<res<<endl;
}