分析

类似 这样的式子,其实我们普通的快速幂是没法解决的,所以考虑拓展欧拉定理。 而题面上已经保证 了,所以可以直接根据 的来化简了 ,我们现在已经把幂降下来了。现在原问题等同于求这个东西 。那么我们现在有个初步的想法,令 表示 ,因为这个是在 意义下进行的,所以考虑枚举 有多少个数。那么总的方案数就可以写作,那么 ,这个显然是个多项式的卷积形式。所以考虑用快速数论变化优化一下,时间复杂度为 。也可以用生成函数来理解。 表示值为 可以出现的次数。那么令 的普通生成函数 ,同理可以求得 。那么最后 。有什么细节和代码上的不懂,欢迎私聊啊。

代码

#include<bits/stdc++.h>
using namespace std;
const int N = 3e6 + 10,g = 3,p = 998244353,gi = 332748118;
#define LL long long
LL L,R[N],A[N],B[N],C[N],limit = 1,inv;
LL phi(LL x) {
    LL ans = x;
    for(LL i = 2;i * i <= x;i++) {
        if(!(x % i)) {
            ans = (ans / i * (i - 1));
            while(!(x % i)) x /= i;
        }
    }
    if(x > 1) ans = (ans / x * (x - 1));
    return ans;
}
LL fastpow(LL a,LL b,LL Mod) {
    LL x = 1;
    for(;b;b >>= 1,a = a * a % Mod) {
        if(b&1) x = x * a % Mod;
    }
    return x;
}
void ntt(LL *a,LL type) {
    for(LL i = 0;i < limit;i++) {
        if(i < R[i]) swap(a[i],a[R[i]]);
    }
    for(LL mid = 1;mid < limit;mid <<= 1) {
        LL wn = fastpow((type == 1)?g:gi,(p-1)/(mid << 1),p);
        for(LL i = 0;i < limit;i += (mid << 1)) {
            LL w = 1;
            for(LL j = 0;j < mid;j++,w = w * wn % p) {
                LL x = a[i + j] , y = w * a[i + j + mid] % p;
                a[i + j] = (x + y) % p;a[i + j + mid] = (x - y + p) % p; 
            }
        }
    }
} 
LL x,y,z,n,k;
int main() {
    cin >> n >> k;
    x = n;
    y = fastpow(n,n,k - 1) + k - 1;
    z = fastpow(n,fastpow(n,n,phi(k - 1)) + phi(k - 1),k - 1) + k - 1;
//    cout << x << " " << y << " " << z << endl;
    for(int i = 1;i <= n;i++) {
        A[fastpow(i,x,k)]++;B[fastpow(i,y,k)]++;C[fastpow(i,z,k)]++;
//        cout << fastpow(i,x,k)<<" "<<fastpow(i,y,k)<<" "<<fastpow(i,z,k)<<endl;
    }
    while(limit <= k + k - 2) limit <<= 1,L++;
    inv = fastpow(limit,p - 2,p);
    for(int i = 0;i < limit;i++) R[i] = (R[i>>1]>>1)|((i&1) << (L-1));
    ntt(A,1);ntt(B,1);
    for(int i = 0;i < limit;i++) A[i] = (A[i] * B[i]) % p;

    ntt(A,-1);//for(int i = 0;i < limit;i++) cout << A[i] << endl;
    LL ans = 0;
    for(int i = 0;i < limit;i++) {
        ans = (ans + C[i % k] * inv % p * A[i] % p) % p;//cout << A[i % k] * inv % p << " " << C[i] << endl;
    }
    printf("%lld\n",ans);
    return 0;
}