ACM模版

描述

题解

正如评论区那个初一的大佬所说的那样……

初一就玩转 FFT 感觉有些丧心病狂啊……内心溅射一万点伤害!我现在连模板都用不好……

代码

#include <iostream>
#include <stdio.h>
#include <cmath>
#include <algorithm>

using namespace std;

typedef unsigned long long ull;

/*
 * Minimal complex number used by the FFT routines in this file.
 *
 * Copying relies on the compiler-generated copy constructor and copy
 * assignment. The equality operator is a genuine, non-mutating comparison:
 * it used to return void and overwrite the left operand, i.e. it was an
 * assignment spelled as `==`.
 */
struct complex
{
    double real, imag;

    // Builds real1 + imag1 * i; both parts default to zero.
    complex(double real1 = 0.0, double imag1 = 0.0)
        : real(real1), imag(imag1)
    {
    }

    // Exact component-wise equality (no tolerance); never modifies either side.
    bool operator == (const complex &a) const
    {
        return real == a.real && imag == a.imag;
    }
};

// Complex product: (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
complex operator * (const complex &n1, const complex &n2)
{
    const double re = n1.real * n2.real - n1.imag * n2.imag;
    const double im = n1.real * n2.imag + n1.imag * n2.real;
    return complex(re, im);
}

// Component-wise complex sum.
complex operator + (const complex &n1, const complex &n2)
{
    const double re = n1.real + n2.real;
    const double im = n1.imag + n2.imag;
    return complex(re, im);
}

// Component-wise complex difference (n1 minus n2).
complex operator - (const complex &n1, const complex &n2)
{
    const double re = n1.real - n2.real;
    const double im = n1.imag - n2.imag;
    return complex(re, im);
}

// Capacity, in chars, of the output text buffer `word` that Print() fills.
const static int WORDSIZE = 500000;
// Element count of the limb arrays and FFT work arrays; init() builds the
// bit-reversal and twiddle tables for this length.
const static int SIZEN = 262144;
// Radix of one limb: Multiplication() carries with / MOD and % MOD, and
// Print() emits four decimal digits per limb.
const static int MOD = 10000;
// Capacity of the range tables p1 / q1. Several functions declare a parameter
// or local that is also called N and shadows this constant.
const static int N = 100000;

// cnta1: number of limbs init() stored in a1.
// cntb1: never written in this file; factorial() reads it as the initial
//        write cursor for b1.
int cnta1, cntb1;
// Limb storage. factorial() alternates between a1 and b1 as source and
// destination of each multiplication round.
int a1[SIZEN];
int b1[SIZEN];
// Bit-reversal table filled by Reverse(); DFT() shifts its entries right to
// adapt them to shorter transform lengths.
int rev[SIZEN];
// Output text buffer; Print() writes into it from the end towards the front.
char word[WORDSIZE];
// e1[i] = (cos(2*pi*i/SIZEN), sin(2*pi*i/SIZEN)) for i < SIZEN/2, built by init().
complex e1[SIZEN >> 1];
// e2[i] is the complex conjugate of e1[i]; Multiplication() passes it to DFT()
// for the final transform of the product.
complex e2[SIZEN >> 1];
// FFT input scratch buffer used by Multiplication().
complex A[SIZEN];
// FFT output buffers used by Multiplication().
complex w1[SIZEN];
complex w2[SIZEN];

// Half-open index range [l, r) locating one big number inside a limb array.
struct BigInteger
{
    int l;  // index of the first limb
    int r;  // index one past the last limb
};

// cntp1: number of ranges init() stored in p1.
// cntq1: never written in this file; factorial() reads it as the initial
//        entry count of q1.
int cntp1, cntq1;
// Range tables. factorial() alternates between p1 and q1 as the operand list
// and the result list of each multiplication round.
BigInteger p1[N];
BigInteger q1[N];

/*
 * Fills rev[0 .. N-1] with bit-mirrored indices.
 *
 * Bit 0 of i is placed at bit position `bitlength`, bit 1 at `bitlength - 1`,
 * and so on until the remaining bits of i are all zero. The result is OR-ed
 * into rev[i], so the table is expected to start out zeroed.
 *
 * N         - number of table entries to fill (shadows the global constant N)
 * bitlength - bit position that receives the lowest bit of each index
 */
void Reverse(int N, int bitlength)
{
    for (int i = 0; i < N; ++i)
    {
        int mirrored = 0;
        int shift = bitlength;
        for (int rest = i; rest != 0; rest >>= 1)
        {
            mirrored |= (rest & 1) << shift;
            --shift;
        }
        rev[i] |= mirrored;
    }
}

/*
 * Length-N transform of A into w1 using radix-2 butterfly passes.
 *
 * A  - input samples, read only
 * w1 - output buffer; receives the permuted input and is then transformed
 *      in place
 * e  - twiddle table indexed in steps of SIZEN / span; pass e1 or its
 *      conjugate table e2 to choose the transform direction
 * N  - transform length; assumed to be a power of two not exceeding SIZEN
 *
 * No 1/N scaling is applied here; the caller divides the final result.
 */
void DFT(complex *A, complex *w1, complex *e, int N)
{
    // Number of binary digits in SIZEN / N: shifting the SIZEN-sized
    // bit-reversal table right by this amount adapts it to length N.
    int shift = 0;
    for (int ratio = SIZEN / N; ratio != 0; ratio >>= 1)
    {
        ++shift;
    }

    // Scatter the input into bit-reversed order.
    for (int idx = 0; idx < N; ++idx)
    {
        w1[idx] = A[rev[idx] >> shift];
    }

    // Butterfly passes over spans of 2, 4, ..., N elements.
    for (int span = 2; span <= N; span <<= 1)
    {
        int log2span = -1;
        for (int s = span; s != 0; s >>= 1)
        {
            ++log2span;
        }
        const int half = span >> 1;
        const int stride = SIZEN >> log2span;

        for (int base = 0; base < N; base += span)
        {
            for (int off = 0; off < half; ++off)
            {
                const int lo = base + off;
                const int hi = lo + half;
                complex u = w1[lo];
                complex t = e[stride * off] * w1[hi];
                w1[lo] = u + t;
                w1[hi] = u - t;
            }
        }
    }
}

/*
 * Multiplies the big number stored in a[l1, r1) by the one stored in
 * a[l2, r2) and writes the product into b starting at index cntb.
 *
 * Numbers are sequences of base-MOD limbs with the least significant limb at
 * the lowest index. On return cntb is one past the highest non-zero limb that
 * was written (it is left unchanged if every product limb is zero).
 *
 * l1, r1 - half-open limb range of the first operand inside a
 * l2, r2 - half-open limb range of the second operand inside a
 * a      - source limb array
 * b      - destination limb array
 * cntb   - in: first index to write in b; out: end of the product
 *
 * Operands of at most 40 limbs each are multiplied with the schoolbook
 * method; anything larger goes through the FFT using the global buffers
 * A, w1, w2 and the twiddle tables e1, e2.
 */
void Multiplication(int l1, int r1, int l2, int r2, int *a, int *b, int &cntb)
{
    int length1 = r1 - l1, length2 = r2 - l2;
    if (length1 <= 40 && length2 <= 40)
    {
        // Clear every destination limb the product (plus carries) can touch.
        int tempcntb = cntb + length1 + length2 + 1;
        for (int i = cntb; i <= tempcntb; i++)
        {
            b[i] = 0;
        }
        // From here on tempcntb remembers where the product starts in b.
        tempcntb = cntb;
        for (int i = l1; i < r1; i++)
        {
            int number = a[i];
            for (int j = l2; j < r2; j++)
            {
                // Limb (i - l1) times limb (j - l2) lands at offset (i - l1) + (j - l2).
                int pos = cntb + i + j - l1 - l2;
                b[pos] += a[j] * number;
                // Carry immediately so that each limb stays below MOD.
                if (b[pos] >= MOD)
                {
                    b[pos + 1] += b[pos] / MOD;
                    b[pos] %= MOD;
                }
            }
        }
        // Move cntb to one past the highest non-zero limb of the product.
        int length = length1 + length2 + 1 + tempcntb;
        for (int k = tempcntb; k <= length; k++)
        {
            if (b[k] != 0)
            {
                cntb = k + 1;
            }
        }

        return ;
    }

    // FFT path: N becomes twice the smallest power of two that is >= the
    // longer operand, which leaves room for all length1 + length2 product limbs.
    int tempLength = max(length1, length2);
    // bitlength is counted here but not read afterwards.
    int N = 1, bitlength = 0;
    while (N < tempLength)
    {
        N <<= 1;
        bitlength++;
    }
    N <<= 1;
    // Transform the first operand (zero-padded to N) into w1.
    for (int i = 0; i < length1; i++)
    {
        A[i] = complex(a[i + l1], 0);
    }
    for (int i = length1; i < N; i++)
    {
        A[i] = complex(0, 0);
    }
    DFT(A, w1, e1, N);
    // Transform the second operand (zero-padded to N) into w2.
    for (int i = 0; i < length2; i++)
    {
        A[i] = complex(a[i + l2], 0);
    }
    for (int i = length2; i < N; i++)
    {
        A[i] = complex(0, 0);
    }
    DFT(A, w2, e1, N);
    // Pointwise product of the two spectra.
    for (int i = 0; i < N; i++)
    {
        A[i] = w2[i] * w1[i];
    }
    // Transform back with the conjugate twiddles; the 1/N scaling is applied
    // in the loop below.
    DFT(A, w1, e2, N);
    // temp accumulates the current limb plus the carry from lower limbs;
    // temp1 holds the carry into the next limb.
    ull temp = 0, temp1 = 0;
    int anslength = 0, lastcntb = cntb;
    double ND = N;
    for (int i = 0; i < N; i++)
    {
        // The sum is truncated when converted back to ull; the 0.1 offset
        // absorbs small negative floating-point error.
        // NOTE(review): 0.1 tolerates less negative error than the usual 0.5
        // rounding offset would - confirm this is intentional.
        temp += w1[i].real / ND + 0.1;
        temp1 = temp / MOD; temp %= MOD;
        b[cntb++] = (int)temp;
        // Track the highest non-zero limb so trailing zero limbs can be dropped.
        if (temp != 0)
        {
            anslength = i + 1;
        }
        temp = temp1;
    }
    // Drop the zero limbs above the most significant non-zero limb.
    cntb = lastcntb + anslength;
}

/*
 * Writes the big number a[0 .. len-1] to stdout in decimal, 1000 digits per
 * line, followed by a newline.
 *
 * len - number of base-10000 limbs in a
 * a   - limbs, least significant first
 *
 * The text is assembled backwards in the global buffer `word`, starting from
 * its end; the capacity of `word` is not checked here.
 *
 * Fixes over the previous version:
 *  - when the digit count was an exact multiple of 1000 the column counter
 *    started at 1000 and never matched the line width again, so the number
 *    was printed without any line breaks;
 *  - an empty or all-zero number read a[-1] or scanned past the end of the
 *    buffer; it now prints "0".
 */
void Print(int len, int *a)
{
    const int linesize = 1000;

    // Ignore zero limbs at the most significant end so the digit count below
    // matches what is actually printed.
    while (len > 0 && a[len - 1] == 0)
    {
        --len;
    }
    if (len <= 0)
    {
        fputs("0\n", stdout);
        return;
    }

    // Total decimal digits: four per limb, minus the leading zeros of the
    // most significant limb.
    int digits = len * 4;
    const int top = a[len - 1];
    if (top <= 9)
    {
        digits -= 3;
    }
    else if (top <= 99)
    {
        digits -= 2;
    }
    else if (top <= 999)
    {
        digits--;
    }

    // Digits are produced from the least significant end, so the last output
    // line is filled first. Start the column counter so that this last line
    // receives digits % linesize digits, or a full line when the remainder is 0.
    int cnt = (linesize - digits % linesize) % linesize;

    char *out = word + WORDSIZE;
    *(--out) = '\n';
    for (int i = 0; i < len; i++)
    {
        int number = a[i];
        // Emit the four decimal digits of this limb, lowest digit first.
        for (int d = 0; d < 4; d++)
        {
            *(--out) = static_cast<char>(number % 10 + '0');
            number /= 10;
            if (++cnt == linesize)
            {
                *(--out) = '\n';
                cnt = 0;
            }
        }
    }

    // Skip the zero padding of the top limb and any line break emitted in
    // front of the first real digit.
    while (*out == '0' || *out == '\n')
    {
        out++;
    }
    fwrite(out, 1, static_cast<size_t>(word + WORDSIZE - out), stdout);
}

void factorial()
{
    BigInteger *p = p1, *q = q1;
    int cntp = cntp1, cntq = cntq1;
    int cnta = cnta1, cntb = cntb1;
    int *a = a1, *b = b1;
    while (1)
    {
        if (cntp == 1)
        {
            Print(p[0].r, a);
            break;
        }
        for (int i = 0; i < (cntp >> 1); i++)
        {
            q[cntq].l = cntb;
            Multiplication(p[i].l, p[i].r, p[cntp - i - 1].l, p[cntp - i - 1].r, a, b, cntb);
            q[cntq++].r = cntb;
        }
        if ((cntp & 1) == 1)
        {
            q[cntq].l = cntb;
            int l = p[(cntp >> 1)].l;
            int r = p[(cntp >> 1)].r;
            for (int i = l; i < r; i++)
            {
                b[cntb++] = a[i];
            }
            q[cntq++].r = cntb;
        }
        swap(p, q);
        swap(a, b);
        swap(cntp, cntq);
        swap(cnta, cntb);
        cntq = 0;
        cntb = 0;
    }
}

/*
 * Prepares the global state for computing n!.
 *
 *  1. Splits 1..n into factors stored as base-MOD limbs in a1, with one range
 *     per factor in p1. Integers up to 99 are greedily multiplied together
 *     while the running product stays below MOD; every larger integer becomes
 *     its own factor (two limbs once it reaches MOD).
 *  2. Builds the bit-reversal table and the twiddle tables e1 / e2 for
 *     transforms of length SIZEN.
 *
 * n - the number whose factorial is wanted. For n <= 1 a single factor 1 is
 *     stored; previously nothing was stored, which left factorial() with an
 *     empty list and made it loop forever.
 */
void init(int n)
{
    int l = 0;
    int n1 = min(n, 99);
    int number = 1;
    for (int i = 1; i <= n1; i++)
    {
        if (number * i >= MOD)
        {
            // The product would no longer fit in one limb: flush it as a
            // factor and start a new group with i.
            a1[cnta1++] = number;
            p1[cntp1].l = l;
            p1[cntp1++].r = cnta1;
            l = cnta1;
            number = i;
        }
        else number *= i;
    }
    // Flush the last group. `cntp1 == 0` covers n <= 1, where the group is
    // just 1 and would otherwise be dropped, leaving no factor at all.
    if (number > 1 || cntp1 == 0)
    {
        a1[cnta1++] = number;
        p1[cntp1].l = l;
        p1[cntp1++].r = cnta1;
        l = cnta1;
    }
    for (int i = n1 + 1; i <= n; i++)
    {
        if (i >= MOD)
        {
            // Two limbs, least significant first.
            a1[cnta1++] = i % MOD;
            a1[cnta1++] = i / MOD;
        }
        else a1[cnta1++] = i;
        p1[cntp1].l = l;
        p1[cntp1++].r = cnta1;
        l = cnta1;
    }

    // Number of binary digits in SIZEN / 2; Reverse() places bit 0 of every
    // index at this position.
    int bitlength = 0;
    for (int rest = SIZEN >> 1; rest != 0; rest >>= 1)
    {
        bitlength++;
    }
    Reverse(SIZEN, bitlength);

    // e1[i] = (cos, sin) of 2*pi*i / SIZEN for the first half turn.
    double pin2 = 2.0 * acos(-1.0) / SIZEN;
    int N1 = SIZEN >> 2, N2 = SIZEN >> 1;
    for (int i = 0; i <= N1; i++)
    {
        e1[i] = complex(cos(pin2 * i), sin(pin2 * i));
    }
    // Second quarter by symmetry: cos(pi - t) = -cos(t), sin(pi - t) = sin(t).
    for (int i = N1 + 1; i < N2; i++)
    {
        e1[i] = complex(-e1[N2 - i].real, e1[N2 - i].imag);
    }
    // e2 holds the complex conjugates of e1.
    for (int i = 0; i < N2; i++)
    {
        e2[i] = complex(e1[i].real, -e1[i].imag);
    }
}

// The integer read from standard input whose factorial is printed.
int n;

/*
 * Reads n from standard input and prints n!.
 *
 * Returns 1 without producing output when no integer can be read or when n is
 * negative; the result of scanf used to be ignored, so malformed input ran
 * the whole pipeline with n left at 0.
 */
int main()
{
    if (scanf("%d", &n) != 1 || n < 0)
    {
        return 1;
    }

    init(n);
    factorial();

    return 0;
}