题目传送门

题目大意

解题思路

这还是camp的题,当时不会,甚至想过莽,因为签完到别的题也不会了
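原式等于 $\sum_{i=1}^{n}\sum_{j=1}^{n} 2^{a_i a_j}$。
因为 $a_i a_j = \dfrac{(a_i+a_j)^2 - a_i^2 - a_j^2}{2}$，所以 $2^{a_i a_j} = \sqrt{2}^{\,(a_i+a_j)^2}\cdot\sqrt{2}^{\,-a_i^2}\cdot\sqrt{2}^{\,-a_j^2}$。
那么原式就相当于 $\sum_{k}\sqrt{2}^{\,k^2}\sum_{a_i+a_j=k}\sqrt{2}^{\,-a_i^2}\,\sqrt{2}^{\,-a_j^2}$。
其中内层的求和 $\sum_{a_i+a_j=k}$ 就相当于是卷积：
我们构造一个以 $a_i$ 作为指数、以 $\sqrt{2}^{\,-a_i^2}$ 作为系数的多项式 $f(x)=\sum_{i=1}^{n}\sqrt{2}^{\,-a_i^2}x^{a_i}$，
则答案就是该多项式平方后的 $\sum_{k}[x^k]f^2(x)\cdot\sqrt{2}^{\,k^2}$，
所以先用 NTT 求出 $f^2(x)$，然后枚举 $k$ 求和。
其中 $\sqrt{2}\equiv 116195171 \pmod{998244353}$，
也就是 $2$ 在模 $998244353$ 意义下的二次剩余（平方根）。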

AC代码

//#pragma GCC optimize(3,"Ofast","inline")
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <queue>
#include <map>
#include <set>
#include <stack>
#include <vector>
#include <string>
#include <iostream>
#include <list>
#include <cstdlib>
#include <bitset>
#include <assert.h>
// #define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
// char buf[(1 << 21) + 1], * p1 = buf, * p2 = buf;
// #define int long long
// Lowest set bit of x (two's complement). The argument is fully parenthesized
// so that expressions such as lowbit(a + b) expand correctly.
#define lowbit(x) ((x) & (-(x)))
// Segment-tree child argument lists; they expect `root`, `l`, `r` and `mid`
// to be in scope at the expansion site.
#define lson root << 1, l, mid
#define rson root << 1 | 1, mid + 1, r
#define pb push_back
typedef unsigned long long ull;
typedef long long ll;
typedef std::pair<ll, ll> pii;  // NOTE: despite the name, holds two long longs
#define bug puts("BUG")
const long long INF = 0x3f3f3f3f3f3f3f3fLL;
const int inf = 0x3f3f3f3f;
const int mod = 998244353;  // NTT-friendly prime: 119 * 2^23 + 1
const double eps = 1e-6;
// Reads the next integer from stdin into x.
// Non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative. If input ends before any digit is
// found, x is left as 0 instead of spinning forever on EOF.
template <class T>
inline void read(T &x)
{
    int sign = 1;
    int c = getchar();  // int, not char: must be able to represent EOF
    x = 0;
    while (c != EOF && (c > '9' || c < '0'))
    {
        if (c == '-')
            sign = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        x = x * 10 + (c - '0');  // subtract first to avoid an intermediate overflow
        c = getchar();
    }
    x *= sign;
}
#ifdef LOCAL
    // Local-testing convenience: read stdin from input.txt.
    // Named without the `_Uppercase` prefix, which C++ reserves for the implementation.
    FILE* local_input_file = freopen("input.txt", "r", stdin);
    // FILE* local_output_file = freopen("output.txt", "w", stdout);
#endif
using namespace std;
// sqrt(2) modulo 998244353: 116195171 * 116195171 % 998244353 == 2.
const int qr2 = 116195171;
// Capacity shared by the coefficient array and the bit-reversal table below.
const int maxn = 1000010;
ll a[maxn];  // polynomial coefficients, transformed in place by NTT
int r[maxn]; // bit-reversal permutation, filled by init()
// Modular exponentiation: returns a^n mod `mod` by binary exponentiation.
// The base is first reduced into [0, mod), so any ll base is accepted and
// `a * a` cannot overflow. The exponent n must be non-negative: with the usual
// arithmetic right shift a negative n never reaches 0 and the loop would not end.
ll qmod(ll a,ll n)
{
    a %= mod;
    if (a < 0)
        a += mod;
    ll ans = 1;
    while (n)
    {
        if (n & 1)
            ans = ans * a % mod;
        a = a * a % mod;
        n >>= 1;
    }
    return ans;
}
// Fills the bit-reversal table r[0..len) for a transform of size len == 1 << l:
// r[i] is i with its lowest l bits reversed (built from r[i >> 1]).
void init(int len, int l)
{
    for (int i = 0; i < len; ++i)
    {
        // For l == 0 (len == 1) there is no top bit; shifting by l - 1 == -1
        // would be undefined behaviour, so contribute 0 instead.
        const int topBit = (l > 0) ? ((i & 1) << (l - 1)) : 0;
        r[i] = (r[i >> 1] >> 1) | topBit;
    }
}

// Permutes y[0..len) into bit-reversed index order using the table r[]
// (expected to have been filled by init() for this len).
void change(ll y[], int len)
{
    for (int idx = 0; idx < len; ++idx)
    {
        const int partner = r[idx];
        if (partner > idx)  // visit each pair once so it is swapped exactly once
            swap(y[idx], y[partner]);
    }
}

// In-place iterative number-theoretic transform of y[0..len) modulo `mod`,
// using 3 as the root generator.
// rev == -1 performs the inverse transform (reverse y[1..len) and scale by
// len^{-1}); any other value performs the forward transform.
// Assumes len is a power of two, r[] was built by init() for this len, and
// the entries of y are already in [0, mod) — TODO confirm at new call sites.
void NTT(ll y[], int len, int rev)
{
    change(y, len);  // bit-reversal reorder
    for (int step = 2; step <= len; step <<= 1)
    {
        const int half = step >> 1;
        const ll root = qmod(3, (mod - 1) / step);  // primitive step-th root of unity
        for (int base = 0; base < len; base += step)
        {
            ll tw = 1;  // current twiddle factor root^off
            for (int off = 0; off < half; ++off)
            {
                ll &lo = y[base + off];
                ll &hi = y[base + off + half];
                const ll even = lo;
                const ll odd = tw * hi % mod;
                lo = (even + odd) % mod;
                hi = (even - odd + mod) % mod;
                tw = tw * root % mod;
            }
        }
    }
    if (rev != -1)
        return;
    // Inverse: evaluating at the conjugate roots equals reversing y[1..len).
    reverse(y + 1, y + len);
    const ll invLen = qmod(len, mod - 2);
    for (int i = 0; i < len; ++i)
        y[i] = y[i] * invLen % mod;
}

int main()
{
    int invqr2 = qmod(qr2, mod - 2);
    ll n, x;
    int len = 1, l = 0;
    read(n);
    for (int i = 0; i < n; ++i)
    {
        read(x);
        int t = qmod(invqr2, x * x);
        while (len <= x * 2 + 1)
        {
            len <<= 1;
            ++l;
        }
        (a[x] += t) %= mod;
    }
    init(len, l);
    NTT(a, len, 1);
    for (int i = 0; i < len; ++i)
        a[i] = a[i] * a[i] % mod;
    NTT(a, len, -1);
    ll res = 0;
    for (ll i = 0; i < len; ++i)
    {
        res = (res + a[i] * qmod(qr2, i * i) % mod) % mod;
    }
    printf("%lld\n", (res + mod) % mod);
}