题意:求两个大数的乘积
解题方法:之前已经用FFT做过了,今天学习一下NTT,记录一下模板。
先所以下NTT,具体的讲解可以看见这里 为了避免FFT在复数和浮点运算中出现精度问题,所以在某些情况下使用NTT。记录一下NTT的模板,其实和FFT很多一样。对了这个具体原理可以看这个博客:见这里

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1<<18;
const int G = 3, P = (479<<21) + 1; //G为原根,P为大素数,可以处理2^21次范围

LL quick(LL x, LL n) {
    LL ret = 1;
    for(; n; n >>= 1) {
        if(n & 1) ret = ret * x % P;
        x = x * x % P;
    }
    return ret;
}

LL A[N], B[N];
void rader(LL* y, int len) {
    for(int i = 1, j = len / 2; i < len - 1; i++) {
        if(i < j) swap(y[i], y[j]);
        int k = len / 2;
        while(j >= k) {j -= k; k /= 2;}
        if(j < k) j += k;
    }
}
void ntt(LL* y, int len, int op) {
    rader(y, len);
    for(int h = 2; h <= len; h <<= 1) {
        LL wn = quick(G, (P - 1) / h);
        if(op == -1) wn = quick(wn, P - 2);
        for(int j = 0; j < len; j += h) {
            LL w = 1;
            for(int k = j; k < j + h / 2; k++) {
                LL u = y[k];
                LL t = w * y[k + h / 2] % P;
                y[k] = (u + t) % P;
                y[k + h / 2] = (u - t + P) % P;
                w = w * wn % P;
            }
        }
    }
    if(op == -1) {
        LL inv = quick(len, P - 2);
        for(int i = 0; i < len; i++) y[i] = y[i] * inv % P;
    }
}
char s1[100010], s2[100010];
LL ans[N];

int main()
{
    while(scanf("%s%s", s1, s2) != EOF)
    {
        int len1 = strlen(s1);
        int len2 = strlen(s2);
        memset(A, 0, sizeof(A));
        memset(B, 0, sizeof(B));
        for(int i = len1 - 1; i >= 0; i--) A[i] = s1[len1-i-1] - '0';
        for(int i = len2 - 1; i >= 0; i--) B[i] = s2[len2-i-1] - '0';
        int len = 1;
        while(len < len1 * 2 || len < len2 * 2) len <<= 1;
        ntt(A, len, 1);
        ntt(B, len, 1);
        for(int i = 0; i < len; i++) A[i] = A[i] * B[i] % P;
        ntt(A, len, -1);
        memset(ans, 0, sizeof(ans));
        for(int i = 0; i < len; i++){
            ans[i] += A[i];
            if(ans[i] >= 10){
                ans[i+1] += ans[i]/10;
                ans[i]%=10;
            }
        }
        int pos = 0;
        for(int i = len-1; i >= 0; i--){
            if(ans[i]){
                pos = i;
                break;
            }
        }
        for(int i = pos; i >= 0; i--){
            printf("%lld", ans[i]);
        }
        printf("\n");
    }
    return 0;
}