E 末日时在做什么?有没有空?可以来学数学吗?

打完表发现规律并不难找,难点在于如何模拟实现 alt 注意当n<=m×m时(这里我的n是自减了的,不自减就是n<=m×m+1),要利用二分去查找n到达了第几项,不然会alt

表中可以发现对于一个输入的m(这里因为前面已经自减n了,所以忽略第一个值m,即a1),首项=2*m-1,公差d=-2,一直递减,这里ai为项数,比如输入100 4时,第一项=7,包含1 1 2 1 3 1 4,第二项=5,包含2 2 3 2 4......二分找出n能够达到第几项,模拟求和即可

当n>=m×m(同上,n已经自减,这里忽略a1)时,求和比较容易,前面m×m项即是上面的等差数列,后面即为m+1 1 m+1 2...m+1 m,m+2 1 m+2 2...m+2 m,......每项包含2×m个数,模拟求和即可

当m=1时,另外讨论求值,这个很简单,不多过诉

附代码:

#include<bits/stdc++.h>
#define endl "\n"
#define int long long
using namespace std;
const int mod = 1e9 + 7;
int n, m;
//unordered_map<int, int>cnt;
//int a[100005];
void solve()
{
    cin >> n >> m;
    //cnt.clear();
    //cnt[m]++;
    //a[1] = m;
    //for (int i = 2; i <= n; i++)
    //{
    //    a[i] = cnt[a[i - 1]];
    //    cnt[a[i]]++;
    //}
    //cout << "a[i]的值:" << endl;
    //for (int i = 1; i <= n; i++)cout << a[i] << " ";
    //cout << endl;
    //cout << "sum[i]:" << endl;
    //int sum = 0;
    //for (int i = 1; i <= n; i++)
    //{
    //    sum += a[i];
    //    cout << sum << " ";
    //}
    //cout << endl;
    if (n == 1)
    {
        cout << m << endl;
    }
    else if (m == 1)
    {
        int sum = 1;
        n--;
        int t = (n + 1) / 2;
        sum = (sum + t) % mod;
        if (n % 2 == 1)t--;
        int a1 = 2, an = 1 + t;
        sum = (sum + (a1 + an) * t / 2 % mod) % mod;
        cout << sum << endl;
    }
    else
    {
        int sum = m;
        n--;
        int maxx = m * m;
        if (n <= m * m)//4<3*3
        {
            //这里得二分求总共有多少项,不然会超时
            int a1 = 2 * m - 1;//5
            int l = 1, r = m;//l=1,r=3;
            int pos = 1;
            while (l <= r)
            {
                int mid = (l + r) / 2;
                int an = a1 - 2 * (mid - 1);
                if (mid * (a1 + an) / 2 >= n)
                {
                    pos = mid;
                    r = mid - 1;
                }
                else l = mid + 1;
            }
            //cout << "pos1:" << pos << endl;//1
            int an = a1 - 2 * (pos - 1);
            int t = (a1 + an) * pos / 2;
            if (t > n)
            {
                pos--;
            }
            //cout << "pos2:" << pos << endl;//0
            sum = (sum + (1 + pos) * pos / 2 % mod) % mod;
            for (int i = 1; i <= pos; i++)
            {
                sum = (sum + i * (m - i) % mod + (i + 1 + m) * (m - i) / 2 % mod) % mod;
            }
            if (t > n)
            {
                t -= an;
                int diff = n - t;
                sum = (sum + pos + 1) % mod;//4
                diff--;//3
                sum = (sum + (pos + 1) * ((diff + 1) / 2) % mod) % mod;//6
                sum = (sum + (pos + 1 + 1 + pos + diff / 2 + 1) * (diff / 2) / 2 % mod) % mod;//8
            }
        }
        else
        {
            sum = (sum + (1 + m) * m / 2 % mod) % mod;
            for (int i = 1; i <= m; i++)
            {
                sum = (sum + i * (m - i) % mod + (i + 1 + m) * (m - i) / 2 % mod) % mod;
            }
            int diff = n - m * m;
            int t = diff / (2 * m);
            int tt = diff % (2 * m);
            int a1 = m + 1;
            int an = m + t;
            sum = (sum + m * (a1 + an) * t / 2 % mod) % mod;
            int ttt = (1 + m) * m / 2 % mod;
            if (ttt == 1)
            {
                cout << sum << endl;
                return;
            }
            sum = (sum + t * ttt % mod) % mod;
            int cnt = 1;
            while (tt)
            {
                sum = (sum + an + 1) % mod;
                tt--;
                if (tt == 0)break;
                sum = (sum + cnt) % mod;
                cnt++;
                tt--;
            }
        }
        cout << sum << endl;
    }
}
signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    int T = 1;
    cin >> T;
    //init();
    while (T--)
    {
        solve();
    }
}