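// Problem link: https://nanti.jisuanke.com/t/42386
// Formula simplification (shown as an image in the original post; reconstructed
// here from init() and work() below), with b = floor(sqrt(n)) and
// k_i = floor(log_i n):
//   answer = \sum_{i=b+1}^{n} i (n + 1 - i)
//          + \sum_{i=2}^{b} i \left( (n+1) k_i - \frac{i (i^{k_i} - 1)}{i - 1} \right)   (mod 998244353)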

#include <bits/stdc++.h>
using namespace std;
// Shorthand for the 64-bit signed integer type used for n and intermediates.
typedef long long ll;

// Modulus under which every result is computed and printed.
const int mod = 998244353;

// Input value, read by main().
ll n;

// Split point between the two halves of the computation; main() sets it to
// the integer part of sqrt(n).
int b;

// Modular inverses of 2 and 6; main() fills them in before they are used.
int inv2;
int inv6;

// Running answer, always kept reduced modulo `mod`.
int ans;

// Fast exponentiation: returns a^b reduced modulo `mod`.
//
// Uses square-and-multiply, walking the bits of b from least to most
// significant. Products are formed in 64-bit (1LL * ...) so they do not
// overflow int before the reduction.
int ksm(int a, int b)
{
    int result = 1;
    while(b)
    {
        // Fold the current power of a into the result when this bit is set.
        if(b & 1) result = 1LL * result * a % mod;
        a = 1LL * a * a % mod;
        b >>= 1;
    }
    return result;
}
// Modular inverse of a, computed as a^(mod - 2) via ksm().
// This is the Fermat-style inverse, so it is meaningful only when `mod` is
// prime and a is not a multiple of `mod`.
int ni(int a)
{
    const int exponent = mod - 2;
    return ksm(a, exponent);
}
// Initialises `ans` with the closed-form part of the answer (mod `mod`):
//
//     (n+1)(n+b+1)(n-b)/2  +  b(b+1)(2b+1)/6  -  n(n+1)(2n+1)/6
//
// i.e. (n+1) * sum_{i=b+1..n} i  -  sum_{i=b+1..n} i^2.
// Reads the globals n, b, inv2, inv6 and overwrites ans.
void init()
{
    // Residues of the individual factors.
    const int nMod = n % mod;
    const int nPlusOne = (n + 1) % mod;
    const int twoNPlusOne = (2 * n + 1) % mod;
    const int nPlusBPlusOne = (n + b + 1) % mod;
    const int nMinusB = (n - b + mod) % mod;

    // (n+1)(n+b+1)(n-b)/2
    const int linearPart = 1LL * nPlusOne * nPlusBPlusOne % mod * nMinusB % mod * inv2 % mod;
    // b(b+1)(2b+1)/6: squares up to b
    const int squaresToB = 1LL * b * (b + 1) % mod * (2 * b + 1) % mod * inv6 % mod;
    // n(n+1)(2n+1)/6: squares up to n
    const int squaresToN = 1LL * nMod * nPlusOne % mod * twoNPlusOne % mod * inv6 % mod;

    ans = linearPart;
    ans = (ans + squaresToB) % mod;
    // Subtract by adding the complement so the value never goes negative.
    ans = (ans + (mod - squaresToN)) % mod;
}
void work()
{
    for(int i = 2; i <= b; i++)
    {
        int k = 1; ll tmp = i;
        while(tmp < n)
        {
            tmp *= i;
            k++;
        }
        if(tmp > n) k--;
        int tmp1 = (n+1) % mod * k % mod;
        int tmp2 = 1LL * i * (ksm(i, k) - 1 + mod) % mod * ni(i-1) % mod;
        (ans += 1LL * (tmp1 - tmp2 + mod) * i % mod) %= mod;
    }
}
int main()
{
    inv2 = ni(2); inv6 = ni(6);
    scanf("%lld", &n);
    if(n == 2) {printf("2\n"); return 0;}
    if(n == 3) {printf("7\n"); return 0;}
    b = sqrt(n);
    init();
    work();
    printf("%d\n", ans);

    return 0;
}