CF 528D. Fuzzy Search NTT

题目大意

给出文本串S和模式串T和k,S,T为DNA序列(只含ATGC)。对于S中的每个位置\(i\),只要中[i-k,i+k]有一个位置匹配了字符\(i\),那么就认为\(i\)可以匹配。求S中有多少位置匹配了T。

思路

一共有四个字母,我们分别计算每个字母是否可行,其他不管。
最后四个都满足的位置就是一个合法位置(指的是初始位置)。
设g[i]表示S_i位置是否是枚举的字母,f[i]表示M_i是否是是枚举的字母。
他们满足条件只需要右斜对角线==len
发现每个点又是右斜对角线,反转ntt

错误

有点zz,4写成了m,还忘记删掉调试了,wrong了两发,ntt真好调试(不用调试)。

代码

#include <bits/stdc++.h>
using namespace std;
const int N=1e6+7,mod=998244353;
int read() {
    int x=0,f=1;char s=getchar();
    for(;s>'9'||s<'0';s=getchar()) if(s=='-') f=-1;
    for(;s>='0'&&s<='9';s=getchar()) x=x*10+s-'0';
    return x*f;
}
int n,m,k,limit=1,l,r[N];
char S[N],T[N];
int q_pow(int a,int b) {
    int ans=1;
    while(b) {
        if(b&1) ans=1LL*ans*a%mod;
        a=1LL*a*a%mod;
        b>>=1;
    }
    return ans;
}
void ntt(int *a,int type) {
    for(int i=0;i<=limit;++i)
        if(i<r[i]) swap(a[i],a[r[i]]);
    for(int mid=1;mid<limit;mid<<=1) {
        int Wn=q_pow(3,(mod-1)/(mid<<1));
        for(int i=0;i<limit;i+=(mid<<1)) {
            for(int j=0,w=1;j<mid;j++,w=1LL*w*Wn%mod) {
                int x=a[i+j],y=1LL*w*a[i+j+mid]%mod;
                a[i+j]=(x+y)%mod;
                a[i+j+mid]=(x+mod-y)%mod;
            }
        }
    }
    if(type==-1) {
        reverse(&a[1],&a[limit]);
        int inv=q_pow(limit,mod-2);
        for(int i=0;i<=limit;++i) a[i]=1LL*a[i]*inv%mod;
    }
}
int AAA[N],f[N],g[N],sum[N];
void solve(char x) {
    memset(f,0,sizeof(f));
    memset(g,0,sizeof(g));
    int tong,h=0,d=0;
    sum[0]=(S[0]==x);
    for(int i=1;i<n;++i) sum[i]=sum[i-1]+(S[i]==x);
    for(int i=0;i<n;++i) {
        int y=i+k>=n ? sum[n-1] : sum[i+k];
        int x=i-k-1 < 0 ? 0 : sum[i-k-1];
        g[i]=(bool)(y-x);
    }
    for(int i=0;i<m;++i) f[m-i-1]=(T[i]==x);
    // for(int i=0;i<n;++i) cout<<g[i]<<" ";cout<<"\n";
    // for(int i=0;i<m;++i) cout<<f[i]<<" ";cout<<"\n";

    ntt(g,1),ntt(f,1);
    for(int i=0;i<=limit;++i) f[i]=1LL*f[i]*g[i]%mod;
    ntt(f,-1);
    int gs=0;
    for(int i=0;i<m;++i) gs+=(T[i]==x);
    for(int i=m-1,js=0;i<=n-1;++i,++js) AAA[js]+=(f[i]==gs);
}
int main() {
    n=read(),m=read(),k=read();
    scanf("%s%s",S,T);
    while(limit<=n+m-2) limit<<=1,l++;
    for(int i=0;i<=limit;++i)
        r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    solve('A'),solve('T'),solve('G'),solve('C');
    int tot=0;
    for(int i=0;i<=n-m+1;++i) tot+=(AAA[i]==4);
    printf("%d\n",tot);
    return 0;
}