// See the comments below for the approach.

#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
// On platforms without AVX2 support, replace avx2 above with avx or an SSE level.
#include<bits/stdc++.h>
using namespace std;
// Shorthand for pair members: p.x / p.y expand to p.first / p.second.
#define x first
#define y second
// Frequently used type aliases.
using PII = pair<int, int>;
using ll = long long;
using ull = unsigned long long;
using uint = unsigned int;
using VS = vector<string>;
using VI = vector<int>;
using VVI = vector<vector<int>>;
// Coordinate-compression buffer: append raw values to vx, call divide() once,
// then query mp().
vector<int> vx;

// Sorts vx ascending and drops duplicate values in place.
inline void divide()
{
    sort(vx.begin(), vx.end());
    auto newEnd = unique(vx.begin(), vx.end());
    vx.erase(newEnd, vx.end());
}

// Number of elements of vx that are <= value. For a value present in the
// (sorted, deduplicated) vx this is its 1-based rank.
inline int mp(int value)
{
    auto past = upper_bound(vx.begin(), vx.end(), value);
    return static_cast<int>(past - vx.begin());
}
// Index of the highest set bit in x's 32-bit pattern, i.e. floor(log2(x))
// for x > 0. Returns -1 for x == 0 (no bit set), where __builtin_clz would be
// undefined behavior.
inline int log_2(int x)
{
    const unsigned bits = static_cast<unsigned>(x);
    return bits ? 31 - __builtin_clz(bits) : -1;
}
// Number of set bits in x's 32-bit two's-complement pattern.
inline int popcount(int x)
{
    const unsigned bits = static_cast<unsigned>(x);
    return __builtin_popcount(bits);
}
// Value of the lowest set bit of x (0 when x == 0), as used by Fenwick trees.
inline int lowbit(int x)
{
    // Negate in unsigned arithmetic: `-x` is signed overflow (UB) for INT_MIN.
    const unsigned bits = static_cast<unsigned>(x);
    return static_cast<int>(bits & (0u - bits));
}
// Integer square root: the largest r >= 0 with r * r <= x (floor(sqrt(x))).
// Returns 0 for x <= 0. Correct over the whole long long range.
inline long long Lsqrt(long long x)
{
    // Invariant: L * L <= x < R * R. 3037000500^2 exceeds LLONG_MAX, so R is a
    // valid upper bound for every x, and each midpoint M < R keeps M * M from
    // overflowing.
    long long L = 0, R = 3037000500LL;
    while (L + 1 < R)
    {
        long long M = L + (R - L) / 2;
        if (M * M <= x) L = M;
        else R = M;
    }
    return L;
}
// Triangular number x * (x + 1) / 2, e.g. the number of non-empty substrings
// of a length-x string. Computed entirely in 64-bit arithmetic.
inline long long cal(int x)
{
    // Widen both factors before multiplying; forming (x + 1) in int would
    // overflow at x == INT_MAX.
    const long long v = static_cast<long long>(x);
    return v * (v + 1) / 2;
}
// Reads n, m, s (length n) and t (length m), then prints how many substrings
// of s contain t as a contiguous substring.
//
// Approach (direct count, not complement): for each 1-based start l, let p be
// the first occurrence of t that starts at or after l. Then s[l..r] contains t
// exactly when r >= p + m - 1, which gives n - (p + m - 1) + 1 substrings for
// this l. If no occurrence starts at or after l, no later l has one either,
// so the scan stops.
void solve()
{
    int n, m;
    cin >> n >> m;
    string s, t;
    cin >> s >> t;

    // pos: 1-based start indices of every (possibly overlapping) occurrence of
    // t in s, found with KMP. They are produced in increasing order.
    vector<int> pos;
    {
        // nxt[i] = length of the longest proper border of the prefix t[1..i]
        // (1-based prefix length; characters are read 0-based from t).
        vector<int> nxt(m + 1, 0);
        for (int i = 2, j = 0; i <= m; ++i)
        {
            while (j && t[i - 1] != t[j]) j = nxt[j];
            if (t[i - 1] == t[j]) ++j;
            nxt[i] = j;
        }
        for (int i = 1, j = 0; i <= n; ++i)
        {
            while (j && s[i - 1] != t[j]) j = nxt[j];
            if (s[i - 1] == t[j]) ++j;
            if (j == m)
            {
                pos.push_back(i - m + 1);
                j = nxt[j];  // fall back so overlapping occurrences are found
            }
        }
    }

    long long res = 0;
    // k indexes the first occurrence starting at or after l. Since l only
    // grows, k only moves forward (equivalent to lower_bound per l).
    size_t k = 0;
    for (int l = 1; l <= n; ++l)
    {
        while (k < pos.size() && pos[k] < l) ++k;
        if (k == pos.size()) break;  // no occurrence starts at or after l
        int tail = pos[k] + m - 1;   // end index of that occurrence
        res += n - tail + 1;         // ends r in [tail, n] all work
    }
    cout << res << '\n';
}
int main()
{
    // Fast iostreams: no C stdio sync, and cin does not flush cout before reads.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int T = 1;
    // For multi-test input, uncomment to read the number of cases.
    //cin>>T;
    for (; T--; )
        solve();
}