算法知识点: 搜索,剪枝

复杂度:

解题思路:

在搜索时分别记录每行、每列、每个九宫格内当前未填写的数字有哪些。
这里采用位运算来加速:

  1. 每行、每列、每个九宫格内,分别用一个9位的二进制数来表示哪些数字可填。
  2. 每个空格内所有可选的数字就是其所在行、列、九宫格内可选数字的交集,这里直接将三个9位二进制数按位与(&)即可求出交集,然后通过运算可以快速枚举出所有是1的位。

另外还需要优化搜索顺序,每次选择分支最少的空格来枚举。

C++ 代码:

#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std; const int N = 9;
 
int ones[1 << N], map[1 << N];
int row[N], col[N], cell[3][3];
int g[N][N];
int ans = -1;
 
inline int lowbit(int x)
{
    return x &-x;
}
 
void init()
{
    for (int i = 0; i < N; i++) row[i] = col[i] = (1 << N) - 1;
    for (int i = 0; i < 3; i++)
        for (int j = 0; j < 3; j++)
            cell[i][j] = (1 << N) - 1;
}
 
// 求可选方案的交集
inline int get(int x, int y)
{
    return row[x] &col[y] &cell[x / 3][y / 3];
}
 
inline int get_score(int x, int y)
{
    return min(min(x, 8 - x), min(y, 8 - y)) + 6;
}
 
bool dfs(int cnt, int score)
{
    if (!cnt)
    {
        ans = max(ans, score);
        return false;
    }
 
    // 找出可选方案数最少的空格
    int minv = 10;
    int x, y;
    for (int i = 0; i < N; i++)
        for (int j = 0; j < N; j++)
            if (!g[i][j])
            {
                int t = ones[get(i, j)];
                if (t < minv)
                {
                    minv = t;
                    x = i, y = j;
                }
            }
 
    for (int i = get(x, y); i; i -= lowbit(i))
    {
        int t = map[lowbit(i)];
 
        // 修改状态
        row[x] -= 1 << t;
        col[y] -= 1 << t;
        cell[x / 3][y / 3] -= 1 << t;
        g[x][y] = t + 1;
 
        if (dfs(cnt - 1, score + get_score(x, y) *(t + 1))) return true;
 
        // 恢复现场
        row[x] += 1 << t;
        col[y] += 1 << t;
        cell[x / 3][y / 3] += 1 << t;
        g[x][y] = 0;
    }
 
    return false;
}
 
int main()
{
    for (int i = 0; i < N; i++) map[1 << i] = i;
    for (int i = 0; i < 1 << N; i++)
    {
        int s = 0;
        for (int j = i; j; j -= lowbit(j)) s++;
        ones[i] = s; // i的二进制表示中有s个1
    }
 
    init();
 
    int cnt = 0, score = 0;
    for (int i = 0, k = 0; i < N; i++)
        for (int j = 0; j < N; j++, k++)
        {
            int x;
            scanf("%d", &x);
            g[i][j] = x;
            if (x)
            {
                row[i] -= 1 << x - 1;
                col[j] -= 1 << x - 1;
                cell[i / 3][j / 3] -= 1 << x - 1;
                score += get_score(i, j) *x;
            }
            else cnt++;
        }
 
    dfs(cnt, score);
 
    cout << ans << endl;
 
    return 0;
}


另外,牛客暑期NOIP真题班限时免费报名
报名链接:https://www.nowcoder.com/courses/cover/live/248
报名优惠券:DCYxdCJ