// 看注释 — see the comments below for the approach.
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
// If the target CPU does not support AVX2, replace "avx2" above with "avx" or an SSE level (e.g. "sse4.2").
#include<bits/stdc++.h>
using namespace std;
// Shorthand so that pair members can be written as p.x / p.y.
// NOTE(review): these are object-like macros, so they rewrite EVERY later identifier
// spelled `x` or `y` in this file (including function parameters), not only pair accesses.
#define x first
#define y second
// Short aliases for frequently used types.
using PII  = pair<int, int>;
using ll   = long long;
using ull  = unsigned long long;
using uint = unsigned int;
using VS   = vector<string>;
using VI   = vector<int>;
using VVI  = vector<vector<int>>;
// Coordinate-compression table: push raw values into vx, call divide(), then look up with mp().
vector<int> vx;
// Sort vx ascending and remove duplicate values in place.
inline void divide()
{
    sort(vx.begin(), vx.end());
    const auto newEnd = unique(vx.begin(), vx.end());
    vx.erase(newEnd, vx.end());
}
// Number of entries of vx that are <= value. For a value present in vx after divide(),
// this is its 1-based rank.
inline int mp(int value)
{
    const auto past = upper_bound(vx.begin(), vx.end(), value);
    return static_cast<int>(past - vx.begin());
}
// Floor of log2(x) for x > 0, i.e. the index of the highest set bit.
// __builtin_clz(0) is undefined, so x == 0 is handled explicitly and yields -1.
// For negative x the sign bit is the highest set bit, so 31 is returned.
inline int log_2(int x) {return x == 0 ? -1 : 31 - __builtin_clz(x);}
// Number of set bits in the two's-complement bit pattern of the argument.
inline int popcount(int value)
{
    const unsigned bits = static_cast<unsigned>(value);
    return __builtin_popcount(bits);
}
// Value of the least-significant set bit of the argument (0 when the argument is 0).
inline int lowbit(int value)
{
    const int negated = -value;
    return negated & value;
}
// Floor of the square root of a 64-bit integer: the largest r with r*r <= x.
// Non-positive input returns 0.
// Binary search keeps lo*lo <= x < hi*hi. hi starts at 3037000500 because its square exceeds
// LLONG_MAX, while every probed mid < hi satisfies mid*mid <= LLONG_MAX (no overflow).
inline long long Lsqrt(long long x)
{
    if (x <= 0) return 0;
    long long lo = 0, hi = 3037000500LL;
    while (lo + 1 < hi)
    {
        const long long mid = lo + (hi - lo) / 2;
        if (mid * mid <= x) lo = mid;
        else hi = mid;
    }
    return lo;
}
// Triangular number x*(x+1)/2 computed in 64-bit arithmetic.
// Both factors are widened before multiplying so that x + 1 cannot overflow int (x == INT_MAX).
inline long long cal(int x)
{
    const long long v = x;
    return v * (v + 1) / 2;
}
void solve()
{
    // Reads "n m", then strings s and t; prints how many substrings s[l..r] contain t.
    // Counted directly (no complement): with 0-based indices, for a start l let e be the end
    // of the earliest occurrence of t that starts at or after l. Exactly the right ends r >= e
    // work, so start l contributes n - e (0 when no such occurrence exists). Sweeping l from
    // right to left keeps e current in O(1), giving O(|s| + |t|) overall with KMP.
    int n = 0, m = 0;
    std::cin >> n >> m;
    std::string s, t;
    std::cin >> s >> t;
    // Use the real lengths so a header that disagrees with the strings cannot index out of range.
    n = static_cast<int>(s.size());
    m = static_cast<int>(t.size());
    if (m == 0 || m > n)
    {
        // No pattern was read, or the pattern cannot fit inside s.
        std::cout << 0 << '\n';
        return;
    }

    // fail[i] = length of the longest proper border of t[0..i] (KMP prefix function).
    std::vector<int> fail(m, 0);
    for (int i = 1, k = 0; i < m; ++i)
    {
        while (k > 0 && t[i] != t[k]) k = fail[k - 1];
        if (t[i] == t[k]) ++k;
        fail[i] = k;
    }

    // startsAt[i] != 0 iff t occurs in s beginning at index i (overlapping occurrences included).
    std::vector<char> startsAt(n, 0);
    for (int i = 0, k = 0; i < n; ++i)
    {
        while (k > 0 && s[i] != t[k]) k = fail[k - 1];
        if (s[i] == t[k]) ++k;
        if (k == m)
        {
            startsAt[i - m + 1] = 1;
            k = fail[k - 1];  // keep scanning so overlapping occurrences are found
        }
    }

    long long res = 0;
    int nearestEnd = -1;  // end index of the earliest occurrence starting at or after l; -1 if none
    for (int l = n - 1; l >= 0; --l)
    {
        if (startsAt[l]) nearestEnd = l + m - 1;
        if (nearestEnd >= 0) res += n - nearestEnd;
    }
    std::cout << res << '\n';
}
int main()
{
    // Fast I/O: stop syncing C++ streams with C stdio and untie cin from cout.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int T = 1;
    // Single test case; uncomment the read below for multi-test input.
    //cin>>T;
    for (int remaining = T; remaining != 0; --remaining)
        solve();
    return 0;
}