A Puzzle: X-Sums Sudoku

题意:考虑一宫大小为 2n×2m2^n\times 2^m 的方形数独, 求横排字典序最小(4×24\times 2 的数独如下)的数度中第 xx 行或列的前或后 XX 个数的和,其中 XX 为第 xx 行或列的第一个数字。

img

其中第二行 8=3+4+18=3+4+1,因为第一个数字为 33 代表取 33 个数字。第二列最下方的 3434 表示取本列后 77 个数字的和。n,m30n,m \leq 30TT 测,T1×105T \leq 1\times 10^5

解法:可以参考以下的打表代码:

#include <bits/stdc++.h>
using namespace std;
const int N = 1 << 7;
bool viscol[N][N*N], visrow[N][N*N], vispalace[N][N*N];
int ans[N * N][N * N];
int main()
{
    int n, m;
    scanf("%d%d", &n, &m);
    n = 1 << n;
    m = 1 << m;
    for (int i = 0; i < n * n * m * m;i++)
    {
        int row = i / (n * m), col = i % (n * m);
        int palace_row = row / n, palace_col = col / m;
        int palace_id = palace_row * n + palace_col;
        for (int j = 0; j < n * m; j++)
            if(!viscol[col][j] && !visrow[row][j] && !vispalace[palace_id][j])
            {
                viscol[col][j] = visrow[row][j] = vispalace[palace_id][j] = 1;
                ans[row][col] = j;
                break;
            }
    }//以上为暴力
    for (int i = 0; i < n * m;i++)
    {
        if(i == 0)
        {
            printf("+");
            for (int j = 0; j < n * m;j++)
                printf("--%c", "-+"[j % m == m - 1]);
            printf("\n");
        }
        printf("|");
        for (int j = 0; j < n * m;j++)
        {
            assert(((i / n) ^ j ^ (i % n * m)) == ans[i][j]);//O(1)解法
            printf("%02X%c", (i / n) ^ j ^ (i % n * m), " |"[j % m == m - 1]);
        }
        printf("\n");
        if (i % n == n - 1)
        {
            printf("+");
            for (int j = 0; j < n * m;j++)
                printf("--%c", "-+"[j % m == m - 1]);
            printf("\n");
        }
    }
    return 0;
}

通过打表或者观察样例可以得到,若将全部数字减一,并且下标均为 0-base,则第 ii 行第 jj 列的数字为 i2nj2m(imod2n)\displaystyle \left \lfloor \dfrac{i}{2^n}\right \rfloor \oplus j \oplus 2^m(i \bmod 2^n)。其中,i2n\left \lfloor \dfrac{i}{2^n}\right \rfloor 表示了横行宫的贡献,jj 为列贡献,2m(imod2n)2^m(i \bmod 2^n) 为宫内行贡献。同时容易注意到,整个数独是中心对称的。因而如果要算第 xx 列从下往上的答案,可以转化到第 2n+mx+12^{n+m}-x+1 列的答案,从右往左同理。

对于横行,X=x2n2m(xmod2n)X=\left \lfloor \dfrac{x}{2^n}\right \rfloor \oplus 2^m(x \bmod 2^n)。其答案为 X+1+i=0X(x2n2m(xmod2n))i\displaystyle X+1+\sum_{i=0}^X\left(\left \lfloor \dfrac{x}{2^n}\right \rfloor \oplus 2^m(x \bmod 2^n)\right) \oplus i。对于此类 iix\sum_{i}i \oplus x,其中 xx 为一定值的,可以分位考虑,考虑第 jj 位为 1100 的个数。若 xx 的第 jj 位为 00 则计入 11 的个数,否则计入 00 的个数。

对于纵列,X=xX=x,其答案为 X+1+i=0X(i2n2m(imod2n))X\displaystyle X+1+\sum_{i=0}^X\left(\left \lfloor \dfrac{i}{2^n}\right \rfloor \oplus 2^m(i \bmod 2^n)\right) \oplus X。不难发现,i2n[0,2m1]\left \lfloor \dfrac{i}{2^n}\right \rfloor \in [0,2^m-1],而 2m(imod2n)2^m (i \bmod 2^n) 对答案的贡献一定在第 mm 个二进制位之上。因而枚举到第 jj 位时,需要根据当前位置进行平移——当 jmj \geq m 时计算的时 imod2ni \bmod 2^n,而 j<mj<m 计算的为 i2n\left \lfloor \dfrac{i}{2^n}\right \rfloor

因而单次询问复杂度为 O(n+m)\mathcal O(n+m)

#include <bits/stdc++.h>
using namespace std;
void print(__int128_t x)
{
    if(!x)
    {
        printf("0\n");
        return;
    }
    string ans = "";
    while(x)
    {
        ans += x % 10 + 48;
        x /= 10;
    }
    reverse(ans.begin(), ans.end());
    printf("%s\n", ans.c_str());
}
long long count(long long n, int digit)
{
    long long block = n >> (digit + 1), res = n - (block << (digit + 1));
    return (block << digit) + max(res - (1ll << digit), 0ll);
}
char buf[40];
int main()
{
    int t, n, m;
    long long x;
    scanf("%d", &t);
    while(t--)
    {
        scanf("%d%d%s%lld", &n, &m, buf, &x);
        x--;
        if (buf[0] == 'b')
        {
            buf[0] = 't';
            x = (1ll << (n + m)) - x - 1;
        }
        else if (buf[0] == 'r')
        {
            buf[0] = 'l';
            x = (1ll << (n + m)) - x - 1;
        }
        if (buf[0] == 'l')
        {
            x = (x >> n) ^ ((x - ((x >> n) << n)) << m);
            __int128_t ans = x + 1;
            for (int k = 0; k < n + m;k++)
            {
                __int128_t now = count(x + 1, k);
                if (x >> k & 1)
                    now = x + 1 - now;
                ans += now << k;
            }
            print(ans);
        }
        else
        {
            __int128_t ans = x + 1;
            for (int k = 0; k < n + m;k++)
            {
                __int128_t now = count(x + 1, k < m ? n + k : k - m);
                if (x >> k & 1)
                    now = x + 1 - now;
                ans += now << k;
            }
            print(ans);
        }
    }
    return 0;
}