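// Problem link: https://nanti.jisuanke.com/t/42386
// Simplified formula evaluated by this program, with b = floor(sqrt(n)) and
// k_i = floor(log_i(n)), all arithmetic modulo 998244353:
//   ans = (n+1)(n+b+1)(n-b)/2 + b(b+1)(2b+1)/6 - n(n+1)(2n+1)/6
//       + sum_{i=2}^{b} i * ( (n+1)*k_i - i*(i^{k_i} - 1)/(i - 1) )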
#include <bits/stdc++.h>
using namespace std;
// 64-bit signed integer shorthand.
typedef long long ll;

// Modulus applied to every intermediate and final result.
const int mod = 998244353;

// Modular inverses of 2 and 6; assigned in main() before any other work.
int inv2 = 0;
int inv6 = 0;

// Running answer, always kept reduced modulo `mod`.
int ans = 0;

// Split point between the closed-form part and the per-base loop;
// main() sets it from the square root of n.
int b = 0;

// Input value read from standard input.
ll n = 0;
// Fast modular exponentiation: returns a^b reduced modulo `mod`.
// Uses square-and-multiply, consuming the exponent one bit at a time.
// Expects b >= 0 (the loop ends only when the shifted exponent reaches zero).
int ksm(int a, int b)
{
    long long base = a;
    long long result = 1;
    while (b != 0)
    {
        // Lowest bit set: this power of the base contributes to the product.
        if ((b & 1) != 0)
        {
            result = result * base % mod;
        }
        base = base * base % mod;
        b >>= 1;
    }
    return static_cast<int>(result);
}
// Modular inverse of a, computed as a^(mod - 2) via ksm().
// This is the Fermat's-little-theorem inverse, so it is meaningful only when
// `mod` is prime and a is not a multiple of `mod`.
int ni(int a)
{
    const int fermat_exponent = mod - 2;
    return ksm(a, fermat_exponent);
}
void init()
{
int tmp1 = (n+1) % mod, tmp2 = (n+b+1) % mod, tmp3 = (n-b+mod) % mod;
int tmp4 = n % mod, tmp5 = (n+1) % mod, tmp6 = (2*n+1) % mod;
ans = 1LL * tmp1 * tmp2 % mod * tmp3 % mod * inv2 % mod;
(ans += 1LL * b * (b+1) % mod * (2*b+1) % mod * inv6 % mod) %= mod;
(ans += mod - 1LL * tmp4 * tmp5 % mod * tmp6 % mod * inv6 % mod) %= mod;
}
void work()
{
for(int i = 2; i <= b; i++)
{
int k = 1; ll tmp = i;
while(tmp < n)
{
tmp *= i;
k++;
}
if(tmp > n) k--;
int tmp1 = (n+1) % mod * k % mod;
int tmp2 = 1LL * i * (ksm(i, k) - 1 + mod) % mod * ni(i-1) % mod;
(ans += 1LL * (tmp1 - tmp2 + mod) * i % mod) %= mod;
}
}
// Entry point: reads n from standard input and prints the result modulo `mod`.
// Steps: precompute the inverses of 2 and 6, set b = floor(sqrt(n)), let
// init() seed `ans` with the closed-form part, and let work() add the
// per-base terms for bases 2..b.
// Returns 1 without printing anything when no integer can be read.
int main()
{
    inv2 = ni(2); inv6 = ni(6);
    // Bail out instead of computing on an unread value.
    if (scanf("%lld", &n) != 1) return 1;
    if(n == 2) {printf("2\n"); return 0;}
    if(n == 3) {printf("7\n"); return 0;}
    // Start from the floating-point estimate, then correct it with exact
    // integer arithmetic so that b*b <= n < (b+1)*(b+1). An underestimated b
    // would hand a base whose square is <= n to the closed form in init(),
    // which assumes every base above b contributes exponent 1 only.
    b = n > 0 ? static_cast<int>(std::sqrt(static_cast<double>(n))) : 0;
    while (b > 0 && 1LL * b * b > n) b--;
    while (1LL * (b + 1) * (b + 1) <= n) b++;
    init();
    work();
    printf("%d\n", ans);
    return 0;
}

// (Web page footer residue, not part of the program: Beijing public network security record no. 11010502036488)