$Dyt's\ math$

求解下式，对 $998244353$ 取模：

$$\sum_{i=1}^n C_i^k b^i$$

$k \le 5 \times 10^5,\ n \le 10^{18},\ 2 \le b \le 998244353.$


$\color{red}{正解部分}$

设 $F_k = \sum\limits_{i=1}^n C_i^k b^i$, 则 $F_{k-1} = \sum\limits_{i=1}^n C_i^{k-1} b^i$,

$bF_{k-1} = \sum\limits_{i=2}^{n+1} C_{i-1}^{k-1} b^i$,

$bF_k = \sum\limits_{i=2}^{n+1} C_{i-1}^k b^i$,

$\therefore bF_{k-1}+bF_k = \sum\limits_{i=2}^{n+1}\left(C_{i-1}^k +C_{i-1}^{k-1}\right) b^i = \sum\limits_{i=2}^{n+1}C_i^k b^i = F_k + C_{n+1}^kb^{n+1}-C_1^kb$ .

$$\therefore F_k = \frac{bF_{k-1}-C_{n+1}^kb^{n+1}+C_1^kb}{1-b}$$

又 $\because F_0 = \sum_{i=1}^n b^i = \frac{b(1-b^n)}{1-b}$

所以可以 $O(K)$ 得到 $F_k$ .


$\color{red}{实现部分}$

#include<bits/stdc++.h>
// Shorthand for the `register` storage-class specifier (used on the loop counter in main).
#define reg register
// 64-bit signed integer alias.
typedef long long ll;

// Size bound; not referenced anywhere in the code visible here.
const int maxn = 500005;
// Modulus applied to every arithmetic result.
const int mod = 998244353;

// K: lower index k of the binomial coefficient C(i, k) in the target sum.
int K;
// B: base b of the powers b^i in the target sum.
int B;

// N: upper limit n of the summation; 64-bit because n can be as large as 10^18.
ll N;

// Binary (square-and-multiply) exponentiation: returns a^b reduced modulo `mod`.
// `b` is expected to be non-negative; b == 0 yields 1.
int Ksm(int a, ll b){
        int result = 1;
        for(; b; b >>= 1){
                // Fold the current power of `a` into the result when this bit of b is set.
                if(b & 1) result = 1ll*result*a % mod;
                // Square the base for the next bit.
                a = 1ll*a*a % mod;
        }
        return result;
}

/*
 * Reads N, B, K from standard input and prints
 *     F_K = sum_{i=1..N} C(i, K) * B^i   (mod `mod`).
 *
 * F_0 is the geometric sum B * (B^N - 1) / (B - 1); each further F_k follows
 * from the recurrence
 *     F_k = (-B * F_{k-1} + C(N+1, k) * B^(N+1) - C(1, k) * B) / (B - 1).
 *
 * Returns 1 without printing anything if the three values cannot be read,
 * 0 otherwise.
 */
int main(){
        if(scanf("%lld%d%d", &N, &B, &K) != 3) return 1;

        // Modular inverse of (B - 1) via Fermat's little theorem.
        const int inv_b = Ksm(B-1, mod-2);
        // B^(N+1) is independent of the loop index, so it is computed once.
        const int pow_b = Ksm(B, N+1);

        // F_0 = B * (B^N - 1) / (B - 1).
        int Ans = 1ll*B*(Ksm(B, N)-1)%mod*inv_b % mod;
        // C holds C(N+1, i) at the start of iteration i; initially C(N+1, 1).
        int C = (N + 1)%mod;
        for(int i = 1; i <= K; i ++){
                // C(1, i) * B: the binomial is non-zero only for i == 1.
                const int C1kb = (i==1)*B;
                Ans = -1ll*B*Ans % mod;
                Ans += 1ll*C*pow_b%mod;
                Ans -= C1kb;
                // Bring the possibly negative intermediate back into [0, mod).
                Ans %= mod, Ans += mod, Ans %= mod;
                Ans = 1ll*Ans*inv_b % mod;
                // C(N+1, i+1) = C(N+1, i) * (N+1-i) / (i+1).
                C = (N-i+1)%mod*C%mod*Ksm(i+1, mod-2)%mod;
        }
        printf("%d\n", Ans);
        return 0;
}