一、
这题真的被卡了一晚上,各种数组清零,各种函数调用,本来精神状态就不太好,结果还调了那么久。
这里有一点不一样的地方就是多项式求逆( inv()函数 )里面的 len 关于 n 的取值和以前不太一样,以前是(n>>1) , 这里是(n-1)>>1 才能过。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#define ll long long
#define llu unsigned ll
using namespace std;
const int maxn=8e5+100;
const int p=998244353;
const int mod=998244353;
const int g=3;
int fi[maxn];
int a[maxn],inva[maxn],da[maxn],lna[maxn],c[maxn],iinv[maxn];
int f[maxn];
int n;

void get_inv(int n)
{
    iinv[1]=1;
    for(int i=2;i<=n;i++)
        iinv[i]=1ll*(mod-mod/i)*iinv[mod%i]%mod;
}

int mypow(int a,int b)
{
    if(b<0) return mypow(mypow(a,p-2),-b);
    int ans=1;
    while(b)
    {
        if(b&1) ans=1ll*ans*a%p;
        a=1ll*a*a%p;
        b>>=1;
    }
    return ans%p;
}

int getlen(int n,int m)
{
    int len=1,cnt=0;
    while(len<=n+m) len<<=1,cnt++;
    for(int i=0;i<len;i++)
        fi[i]=((fi[i>>1]>>1)|((i&1)<<(cnt-1)));
    return len;
}

void ntt(int *x,int len,int f)
{
    for(int i=0;i<len;i++)
        if(i<fi[i]) swap(x[i],x[fi[i]]);

    for(int i=1;i<len;i<<=1)
    {
        int r=i<<1;
        int wn=mypow(g,f*(p-1)/r);
        for(int j=0;j<len;j+=r)
        {
            int w=1;
            for(int k=0;k<i;k++)
            {
                int xx=x[j+k],yy=1ll*w*x[j+i+k]%p;
                x[j+k]=(xx+yy)%p;
                x[j+i+k]=((xx-yy)%p+p)%p;
                w=1ll*w*wn%p;
            }
        }
    }
    if(f==-1)
    {
        int invn=mypow(len,p-2);
        for(int i=0;i<len;i++)
            x[i]=1ll*x[i]*invn%p;
    }
}


void inv(int n,int *a,int *b)
{
    if(n==1)
    {
        b[0]=mypow(a[0],p-2);
        return ;
    }
    inv((n+1)>>1,a,b);
    int len=1,cnt=0;
    while(len<=((n-1)<<1)) len<<=1,cnt++;//这里改成n-1了
    for(int i=0;i<len;i++)
    {
        fi[i]=((fi[i>>1]>>1)|((i&1)<<(cnt-1)));
        c[i]=(i<n?a[i]:0);
        b[i]=(i<n?b[i]:0);
    }
    ntt(c,len,1);
    ntt(b,len,1);
    for(int i=0;i<len;i++)
        b[i]=(2-(ll)c[i]*b[i]%p+p)%p*b[i]%p;
    ntt(b,len,-1);
    for(int i=n;i<len;i++)
        b[i]=0;
}

void dao(int *a,int *b,int n,int len)
{
    for(int i=1;i<n;i++)
        b[i-1]=(ll)i*a[i]%mod;
    for(int i=n-1;i<len;i++)
        b[i]=0;
}

void ji(int *a,int n,int len)
{
    for(int i=n-2;i>=0;i--)
        a[i+1]=(ll)a[i]*iinv[i+1]%mod;
    a[0]=0;
    for(int i=n;i<len;i++)
        a[i]=0;
}

void ln(int *a,int n)
{
    int len=getlen(n-2,n-1);
    dao(a,da,n,len);
    inv(n-1,a,inva);

    ntt(da,len,1);
    ntt(inva,len,1);
    for(int i=0;i<len;i++) lna[i]=(ll)da[i]*inva[i]%mod;
    ntt(lna,len,-1);

    ji(lna,n,len);
}

void eexp(int *f,int *a,int n)
{
    if(n==1)
    {
        a[0]=1;
        return ;
    }
    eexp(f,a,(n+1)>>1);
    ln(a,n);
    int len=getlen(n-1,n-1);
    for(int i=0;i<n;i++)
        lna[i]=(-lna[i]+f[i]+mod)%mod;

    lna[0]++;
    ntt(a,len,1);
    ntt(lna,len,1);
    for(int i=0;i<len;i++) a[i]=1ll*a[i]*lna[i]%mod;
    ntt(a,len,-1);
    for(int i=n;i<len;i++)
        a[i]=0;

}

int main(void)
{
    scanf("%d",&n);
    get_inv(n);
    for(int i=0;i<n;i++)
        scanf("%d",&f[i]);
    eexp(f,a,n);
    for(int i=0;i<n;i++)
        printf("%d ",a[i]);
    putchar('\n');
    return 0;

}



其实n变成n-1并不是多项式求逆出了问题,是因为我们在每一步中都占用了getlen和fi,导致求ln的时候,如果求inv用n,求完inv之后fi就变了。

对于dao,ji所涉及数组的清空,也可放置ln里面。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#define ll long long
#define llu unsigned ll
using namespace std;
const int maxn=8e5+100;
const int p=998244353;
const int mod=998244353;
const int g=3;
int fi[maxn];
int a[maxn],inva[maxn],da[maxn],lna[maxn],c[maxn],iinv[maxn];
int f[maxn];
int n;

void get_inv(int n)
{
    iinv[1]=1;
    for(int i=2;i<=n;i++)
        iinv[i]=1ll*(mod-mod/i)*iinv[mod%i]%mod;
}

int mypow(int a,int b)
{
    if(b<0) return mypow(mypow(a,p-2),-b);
    int ans=1;
    while(b)
    {
        if(b&1) ans=1ll*ans*a%p;
        a=1ll*a*a%p;
        b>>=1;
    }
    return ans%p;
}

int getlen(int n,int m)
{
    int len=1,cnt=0;
    while(len<=n+m) len<<=1,cnt++;
    for(int i=0;i<len;i++)
        fi[i]=((fi[i>>1]>>1)|((i&1)<<(cnt-1)));
    return len;
}

void ntt(int *x,int len,int f)
{
    for(int i=0;i<len;i++)
        if(i<fi[i]) swap(x[i],x[fi[i]]);

    for(int i=1;i<len;i<<=1)
    {
        int r=i<<1;
        int wn=mypow(g,f*(p-1)/r);
        for(int j=0;j<len;j+=r)
        {
            int w=1;
            for(int k=0;k<i;k++)
            {
                int xx=x[j+k],yy=1ll*w*x[j+i+k]%p;
                x[j+k]=(xx+yy)%p;
                x[j+i+k]=((xx-yy)%p+p)%p;
                w=1ll*w*wn%p;
            }
        }
    }
    if(f==-1)
    {
        int invn=mypow(len,p-2);
        for(int i=0;i<len;i++)
            x[i]=1ll*x[i]*invn%p;
    }
}


void inv(int n,int *a,int *b)
{
    if(n==1)
    {
        b[0]=mypow(a[0],p-2);
        return ;
    }
    inv((n+1)>>1,a,b);
    int len=1,cnt=0;
    while(len<=(n<<1)) len<<=1,cnt++;
    for(int i=0;i<len;i++)
    {
        fi[i]=((fi[i>>1]>>1)|((i&1)<<(cnt-1)));
        c[i]=(i<n?a[i]:0);
        b[i]=(i<n?b[i]:0);
    }
    ntt(c,len,1);
    ntt(b,len,1);
    for(int i=0;i<len;i++)
        b[i]=(2-(ll)c[i]*b[i]%p+p)%p*b[i]%p;
    ntt(b,len,-1);
    for(int i=n;i<len;i++)
        b[i]=0;
}

void dao(int *a,int *b,int n,int len)
{
    for(int i=1;i<n;i++)
        b[i-1]=(ll)i*a[i]%mod;
    for(int i=n-1;i<len;i++)
        b[i]=0;
}

void ji(int *a,int n,int len)
{
    for(int i=n-2;i>=0;i--)
        a[i+1]=(ll)a[i]*iinv[i+1]%mod;
    a[0]=0;
    for(int i=n;i<len;i++)
        a[i]=0;
}

void ln(int *a,int n)
{
    int len=getlen(n-2,n-1);
    dao(a,da,n,len);
    inv(n-1,a,inva);
    len=getlen(n-2,n-1);
    ntt(da,len,1);
    ntt(inva,len,1);
    for(int i=0;i<len;i++) lna[i]=(ll)da[i]*inva[i]%mod;
    ntt(lna,len,-1);

    ji(lna,n,len);
}

void eexp(int *f,int *a,int n)
{
    if(n==1)
    {
        a[0]=1;
        return ;
    }
    eexp(f,a,(n+1)>>1);
    ln(a,n);
    int len=getlen(n-1,n-1);
    for(int i=0;i<n;i++)
        lna[i]=(-lna[i]+f[i]+mod)%mod;

    lna[0]++;
    ntt(a,len,1);
    ntt(lna,len,1);
    for(int i=0;i<len;i++) a[i]=1ll*a[i]*lna[i]%mod;
    ntt(a,len,-1);
    for(int i=n;i<len;i++)
        a[i]=0;

}

int main(void)
{
    scanf("%d",&n);
    get_inv(n);
    for(int i=0;i<n;i++)
        scanf("%d",&f[i]);
    eexp(f,a,n);
    for(int i=0;i<n;i++)
        printf("%d ",a[i]);
    putchar('\n');
    return 0;

}




二、
这是今天做学军中学的一道题目,看到这个质数就知道得用 ntt ,可是化了好久也不知道用什么东西来算。
强行莽一发。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#define ll long long
#define llu unsigned ll
using namespace std;
const int maxn=8e5+100;
const int p=950009857;
const int mod=950009857;
const int g=7;
int fi[maxn];
int a[maxn],inva[maxn],da[maxn],lna[maxn],c[maxn],iinv[maxn];
int f[maxn],fac[maxn],invf[maxn];
int n,k,x;


int mypow(int a,int b)
{
    if(b<0) return mypow(mypow(a,p-2),-b);
    int ans=1;
    while(b)
    {
        if(b&1) ans=1ll*ans*a%p;
        a=1ll*a*a%p;
        b>>=1;
    }
    return ans%p;
}

void get_inv(int n)
{
    iinv[1]=1;
    fac[0]=fac[1]=1;
    for(int i=2;i<=n;i++)
    {
        iinv[i]=1ll*(mod-mod/i)*iinv[mod%i]%mod;
        fac[i]=1ll*fac[i-1]*i%mod;
    }
    invf[n]=mypow(fac[n],mod-2);
    for(int i=n-1;i>=0;i--)
        invf[i]=1ll*invf[i+1]*(i+1)%mod;

}
int getlen(int n,int m)
{
    int len=1,cnt=0;
    while(len<=n+m) len<<=1,cnt++;
    for(int i=0;i<len;i++)
        fi[i]=((fi[i>>1]>>1)|((i&1)<<(cnt-1)));
    return len;
}

void ntt(int *x,int len,int f)
{
    for(int i=0;i<len;i++)
        if(i<fi[i]) swap(x[i],x[fi[i]]);

    for(int i=1;i<len;i<<=1)
    {
        int r=i<<1;
        int wn=mypow(g,f*(p-1)/r);
        for(int j=0;j<len;j+=r)
        {
            int w=1;
            for(int k=0;k<i;k++)
            {
                int xx=x[j+k],yy=1ll*w*x[j+i+k]%p;
                x[j+k]=(xx+yy)%p;
                x[j+i+k]=((xx-yy)%p+p)%p;
                w=1ll*w*wn%p;
            }
        }
    }
    if(f==-1)
    {
        int invn=mypow(len,p-2);
        for(int i=0;i<len;i++)
            x[i]=1ll*x[i]*invn%p;
    }
}


void inv(int n,int *a,int *b)
{
    if(n==1)
    {
        b[0]=mypow(a[0],p-2);
        return ;
    }
    inv((n+1)>>1,a,b);
    int len=1,cnt=0;
    while(len<=((n-1)<<1)) len<<=1,cnt++;
    for(int i=0;i<len;i++)
    {
        fi[i]=((fi[i>>1]>>1)|((i&1)<<(cnt-1)));
        c[i]=(i<n?a[i]:0);
        b[i]=(i<n?b[i]:0);
    }
    ntt(c,len,1);
    ntt(b,len,1);
    for(int i=0;i<len;i++)
        b[i]=(2-(ll)c[i]*b[i]%p+p)%p*b[i]%p;
    ntt(b,len,-1);
    for(int i=n;i<len;i++)
        b[i]=0;
}

void dao(int *a,int *b,int n,int len)
{
    for(int i=1;i<n;i++)
        b[i-1]=(ll)i*a[i]%mod;
    for(int i=n-1;i<len;i++)
        b[i]=0;
}

void ji(int *a,int n,int len)
{
    for(int i=n-2;i>=0;i--)
        a[i+1]=(ll)a[i]*iinv[i+1]%mod;
    a[0]=0;
    for(int i=n;i<len;i++)
        a[i]=0;
}

void ln(int *a,int n)
{
    int len=getlen(n-2,n-1);
    dao(a,da,n,len);
    inv(n-1,a,inva);

    ntt(da,len,1);
    ntt(inva,len,1);
    for(int i=0;i<len;i++) lna[i]=(ll)da[i]*inva[i]%mod;
    ntt(lna,len,-1);

    ji(lna,n,len);
}

void eexp(int *f,int *a,int n)
{
    if(n==1)
    {
        a[0]=1;
        return ;
    }
    eexp(f,a,(n+1)>>1);
    ln(a,n);
    int len=getlen(n-1,n-1);
    for(int i=0;i<n;i++)
        lna[i]=(-lna[i]+f[i]+mod)%mod;

    lna[0]++;
    ntt(a,len,1);
    ntt(lna,len,1);
    for(int i=0;i<len;i++) a[i]=1ll*a[i]*lna[i]%mod;
    ntt(a,len,-1);
    for(int i=n;i<len;i++)
        a[i]=0;

}

int main(void)
{
    scanf("%d%d",&n,&k);
    ++n;
    get_inv(n);
    for(int i=1;i<=k;i++)
    {
        scanf("%d",&x);
        f[x]=1ll*fac[x-1]*invf[x]%mod;
    }
    eexp(f,a,n);
    for(int i=1;i<n;i++)
        printf("%lld\n",1ll*a[i]*fac[i]%mod);
    return 0;

}