// Problem: https://nanti.jisuanke.com/t/42386
// Solution: simplify the summation into closed-form formulas.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Reads n from stdin and prints, modulo 998244353,
//
//     S(n) = sum_{i=2}^{n} i * sum_{j=1}^{n} floor(log_i j).
//
// For a fixed base i, the inner sum counts pairs (t, j) with t >= 1 and
// i^t <= j <= n, so it equals the sum of (n - p + 1) over the powers
// p = i^t <= n. With b = floor(sqrt(n)), every base i > b has exactly one
// such power (p = i), so those bases together contribute
// sum_{i=b+1}^{n} i * (n + 1 - i), which has a closed form. The O(sqrt(n))
// bases 2..b are enumerated individually.

const ll mod = 998244353;
// Modular inverses of 2 and 6: 2 * inv2 == 6 * inv6 == mod + 1 == 1 (mod `mod`).
const ll inv2 = (mod + 1) / 2;
const ll inv6 = (mod + 1) / 6;

// Largest r >= 0 with r * r <= n (0 for n <= 0).
// The floating-point estimate can be off by one once n exceeds 2^53 (n is not
// exactly representable), so it is corrected with exact integer tests.
// Comparing against n / r instead of computing r * r avoids overflow.
ll floor_sqrt(ll n) {
    if (n <= 0) return 0;
    ll r = static_cast<ll>(sqrt(static_cast<long double>(n)));
    while (r > 0 && r > n / r) --r;        // r * r > n
    while (r + 1 <= n / (r + 1)) ++r;      // (r + 1) * (r + 1) <= n
    return r;
}

// 1^2 + 2^2 + ... + x^2 = x(x+1)(2x+1)/6, reduced mod `mod` (x >= 0).
// x is reduced first so 2x+1 cannot overflow.
ll square_prefix_sum(ll x) {
    const ll xm = x % mod;
    return xm * ((xm + 1) % mod) % mod * ((2 * xm + 1) % mod) % mod * inv6 % mod;
}

// sum_{i=b+1}^{n} i * (n + 1 - i) mod `mod`, for 0 <= b <= n.
// When b = floor(sqrt(n)) this is the total contribution of all bases whose
// only power not exceeding n is the base itself.
ll large_bases_sum(ll n, ll b) {
    const ll nm = n % mod;
    const ll bm = b % mod;
    // sum_{i=b+1}^{n} i = (n + b + 1)(n - b) / 2
    const ll linear = (nm + bm + 1) % mod * ((n - b) % mod) % mod * inv2 % mod;
    // sum_{i=b+1}^{n} i^2
    const ll squares = (square_prefix_sum(n) - square_prefix_sum(b) + mod) % mod;
    return ((nm + 1) % mod * linear % mod - squares + mod) % mod;
}

// i * sum_{j=1}^{n} floor(log_i j) mod `mod`, for 2 <= i <= n.
// Walks the powers p = i, i^2, ... <= n; each one adds n - p + 1 to the inner
// sum. Summing directly avoids the geometric-series division by (i - 1),
// which has no modular inverse when i % mod == 1. The p > n / i test stops
// before p * i could overflow.
ll small_base_contribution(ll i, ll n) {
    ll inner = 0;
    for (ll p = i;; p *= i) {
        inner = (inner + (n - p + 1) % mod) % mod;
        if (p > n / i) break;  // next power would exceed n
    }
    return i % mod * inner % mod;
}

int main() {
    ll n = 0;
    if (scanf("%lld", &n) != 1) return 1;
    if (n < 2) {  // no bases in 2..n: the sum is empty
        printf("0\n");
        return 0;
    }
    const ll b = floor_sqrt(n);
    ll ans = large_bases_sum(n, b);
    for (ll i = 2; i <= b; ++i) {
        ans = (ans + small_base_contribution(i, n)) % mod;
    }
    printf("%lld\n", ans);
    return 0;
}