All with Pairs

题目链接

题目大意

给n个字符串,f(s1,s2) 代表 s1的前缀和s2的后缀相同的最长长度。

求上面的那个式子。

题解

如果算的不是最长长度而是所有的话,就直接先hash一下后缀,在每个前缀里找有多少个后缀与这个前缀相等就可以。
但是这个题要求最大值,并不是所有的,那么怎么去除除了最大值的呢?
比如 两个字符串 aba 与 aba
a 与 a 算一次 ,aba 与 aba 又算一次。
所以得删除掉a算的那次,怎么去掉那次呢?
先看一下什么情况要删除吧
也就是 算前缀的时候从前往后遍历,如果遇到一个 就把他前面的删掉,也就是 一个前缀k 要把k 的最长的前缀跟k最长后缀相等的就把那个前缀的 减去这个点。

例如这个,如果以j为结尾的前缀是下面字符串的后缀,那么 根据kmp中next的定义以next[j] - 1结尾的也肯定是它的后缀,所以应该减去。
不能直接清零,应该在next[j] - 1的地方减去j 匹配的数量,因为j匹配的next[j] - 1一定可以匹配,但是next[j] - 1匹配的地方j不一定可以匹配。
记得取模!!!!

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>
#include <stack>
#include <queue>
#include <map>
#include <cmath>
#include <set>
#include <cstring>
#include <string>
#include <bitset>
#include <stdlib.h>
#include <time.h>
using namespace std;
typedef long long ll;
typedef pair<int,ll> pii;
typedef unsigned long long ull;
typedef set<int>::iterator sit;
#define st first
#define sd second
#define mkp make_pair
#define pb push_back
void wenjian(){
   freopen("concatenation.in","r",stdin);freopen("concatenation.out","w",stdout);}
void tempwj(){
   freopen("hash.in","r",stdin);freopen("hash.out","w",stdout);}
ll gcd(ll a,ll b){
   return b == 0 ? a : gcd(b,a % b);}
ll qpow(ll a,ll b,ll mod){
   a %= mod;ll ans = 1;while(b){
   if(b & 1)ans = ans * a % mod;a = a * a % mod;b >>= 1;}return ans;}
struct cmp{
   bool operator()(const pii & a, const pii & b){
   return a.second > b.second;}};
int lb(int x){
   return  x & -x;}
const int INF = 0x3f3f3f3f;
const ll inf = 0x3f3f3f3f3f3f3f3f;
const int maxn = 1e5 + 4;
const int M = 4e4+2;
const ll mod = 998244353;
string a[maxn];
map<ull,ll> mm;
ull seed = 131;
void chuli(string a)
{
   
    int m = a.length();
    ull temp = 1;
    ull t = 0;
    for (int i = m - 1; i >= 0; i -- )
    {
   
        t += (a[i] - 'a' + 1) * temp;
        temp *= seed;
        mm[t] ++;
    }
}
int nex[maxn];
int ans[maxn];
int main()
{
   
    int n;
    scanf("%d",&n);
    for (int i = 1; i <= n; i ++ )
    {
   
        cin>>a[i];
        chuli(a[i]);
    }
// cout<<"111"<<endl;
    ll res = 0;
    for (int i = 1; i <= n; i ++ )
    {
   
        int m = a[i].length();
        nex[0] = 0;
        int k = 0;
        for (int j = 1; j < m; j ++ )
        {
   
            while(k && a[i][j] != a[i][k])
            {
   
// cout<<k<<endl;
                k = nex[k - 1];
            }
            if(a[i][j] == a[i][k])
                k ++ ;
            nex[j] = k;
        }
// for (int j = 0; j < m; j ++ )
// printf("%d ",nex[j]);
// printf("\n");
        ull t = 0;
        for (int j = 0; j < m; j ++ )
        {
   
            t = t * seed + (a[i][j] - 'a' + 1);
            ans[j] = mm[t];
            if(nex[j])
                ans[nex[j] - 1] -= ans[j];
        }
        for (int j = 0; j < m; j ++ )
        {
   
// cout<<ans[j]<<" ";
            res += 1ll * (j + 1) * (j + 1) % mod * ans[j]% mod;
            res %= mod;
        }
// cout<<endl;
    }
    cout<<res<<endl;
}