题目链接

八数码

题目描述

在一个 3x3 的方格中,有数字1到8和'x'(代表空格)共九个元素。每次操作可以将空格'x'与其上下左右相邻的数字进行交换。给定一个初始状态,你需要找出将其恢复为目标状态 12345678x 所需的最少交换次数。

解题思路

  1. 问题建模:状态图搜索

    我们可以将八数碼的每一种可能的排列看作一个状态,每次交换操作看作是状态之间的转移。这样,整个问题就转化为了在一个巨大的状态图中,寻找从初始状态目标状态最短路径

    求解无权图的最短路径问题,最经典的算法就是广度优先搜索 (BFS)

  2. 朴素 BFS

    一个直接的想法是从初始状态开始进行 BFS。

    • 状态表示:用一个长度为9的字符串来表示 3x3 的棋盘状态。
    • 队列:队列中存储 (状态字符串, 到达该状态的步数)
    • 判重:用一个哈希集合(setmap)来记录已经访问过的状态,防止重复搜索和陷入死循环。
    • 流程:从初始状态入队,每次取出队首状态,生成所有下一步可能的状态,如果新状态未被访问过,则将其入队并标记为已访问。直到找到目标状态为止。
  3. 优化:双向广度优先搜索 (Bi-directional BFS)

    朴素BFS的搜索范围像一个不断扩大的圆,其面积(搜索的节点数)与半径(搜索深度 d)呈指数关系增长(b是分支因子)。

    双向BFS通过同时从起点终点进行搜索来极大地优化这个过程。它像是两个相向而行的、不断扩大的圆。当两个圆相遇时,搜索就结束了。

    • 效率:两个半径为 d/2 的小圆的面积之和远小于一个半径为 d 的大圆的面积()。这使得搜索效率得到巨大提升。
    • 实现
      1. 需要两个队列,q_startq_end,分别用于从起点和终点开始的搜索。
      2. 需要两个哈希表,dist_startdist_end,用于记录从起点和终点到达某个状态的最短距离,并兼具判重功能。
      3. 将初始状态和目标状态分别加入各自的队列和哈希表。
      4. 在主循环中,为了保持两个搜索“圆”的大小相当,每次选择当前元素较少的那个队列进行扩展。
      5. 当从一个方向(如起点)扩展出一个新状态 next_state 时,检查这个状态是否在另一个方向的哈希表 dist_end 中出现过。
      6. 如果出现过,说明两条搜索路径在此相遇了。总的最短路径长度就是 dist_start[current_state] + 1 + dist_end[next_state]。搜索结束。
      7. 如果没有相遇,且 next_state 未被当前方向访问过,则将其加入当前方向的队列和哈希表。
  4. 关于可解性

    不是所有的八数码初始状态都能恢复到目标状态。一个状态是否可解,取决于其逆序对的数量。逆序对是指将 3x3 矩阵按行读成一维序列后,所有数字(不含空格)中,排在前面的数大于排在后面的数的对数。对于八数码问题,只有逆序对数量为偶数的状态才是可解的。本题的测试数据保证了所有初始状态都是可解的。

代码

#include <iostream>
#include <vector>
#include <string>
#include <queue>
#include <map>
#include <algorithm>

using namespace std;

string get_initial_state() {
    string s = "";
    char c;
    for (int i = 0; i < 9; ++i) {
        cin >> c;
        if (c != ' ' && c != '\n') {
            s += c;
        }
    }
    return s;
}

int bfs(const string& start_state, const string& end_state) {
    if (start_state == end_state) return 0;

    queue<string> q_start, q_end;
    map<string, int> dist_start, dist_end;

    q_start.push(start_state);
    dist_start[start_state] = 0;
    q_end.push(end_state);
    dist_end[end_state] = 0;

    int dr[] = {-1, 1, 0, 0}; // Up, Down, Left, Right
    int dc[] = {0, 0, -1, 1};

    while (!q_start.empty() && !q_end.empty()) {
        // 优先扩展较小的队列
        if (q_start.size() <= q_end.size()) {
            string current = q_start.front();
            q_start.pop();
            int dist = dist_start[current];

            int x_pos = current.find('x');
            int r = x_pos / 3;
            int c = x_pos % 3;

            for (int i = 0; i < 4; ++i) {
                int nr = r + dr[i];
                int nc = c + dc[i];

                if (nr >= 0 && nr < 3 && nc >= 0 && nc < 3) {
                    string next = current;
                    swap(next[x_pos], next[nr * 3 + nc]);

                    if (dist_end.count(next)) {
                        return dist + 1 + dist_end[next];
                    }
                    if (!dist_start.count(next)) {
                        dist_start[next] = dist + 1;
                        q_start.push(next);
                    }
                }
            }
        } else {
            string current = q_end.front();
            q_end.pop();
            int dist = dist_end[current];

            int x_pos = current.find('x');
            int r = x_pos / 3;
            int c = x_pos % 3;

            for (int i = 0; i < 4; ++i) {
                int nr = r + dr[i];
                int nc = c + dc[i];

                if (nr >= 0 && nr < 3 && nc >= 0 && nc < 3) {
                    string next = current;
                    swap(next[x_pos], next[nr * 3 + nc]);

                    if (dist_start.count(next)) {
                        return dist + 1 + dist_start[next];
                    }
                    if (!dist_end.count(next)) {
                        dist_end[next] = dist + 1;
                        q_end.push(next);
                    }
                }
            }
        }
    }
    return -1; // Should not be reached for solvable puzzles
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    
    string start_state = "";
    char val;
    for(int i = 0; i < 9; ++i) {
        cin >> val;
        start_state += val;
    }

    string end_state = "12345678x";

    cout << bfs(start_state, end_state) << endl;

    return 0;
}
import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < 9; i++) {
            sb.append(sc.next());
        }
        String startState = sb.toString();
        String endState = "12345678x";

        System.out.println(bfs(startState, endState));
    }

    private static int bfs(String start, String end) {
        if (start.equals(end)) return 0;

        Queue<String> qStart = new LinkedList<>();
        Map<String, Integer> distStart = new HashMap<>();
        Queue<String> qEnd = new LinkedList<>();
        Map<String, Integer> distEnd = new HashMap<>();

        qStart.add(start);
        distStart.put(start, 0);
        qEnd.add(end);
        distEnd.put(end, 0);

        int[] dr = {-1, 1, 0, 0};
        int[] dc = {0, 0, -1, 1};

        while (!qStart.isEmpty() && !qEnd.isEmpty()) {
            if (qStart.size() <= qEnd.size()) {
                int size = qStart.size();
                for (int i = 0; i < size; i++) {
                    String current = qStart.poll();
                    int dist = distStart.get(current);
                    int xPos = current.indexOf('x');
                    int r = xPos / 3, c = xPos % 3;

                    for (int j = 0; j < 4; j++) {
                        int nr = r + dr[j];
                        int nc = c + dc[j];
                        if (nr >= 0 && nr < 3 && nc >= 0 && nc < 3) {
                            String next = swap(current, xPos, nr * 3 + nc);
                            if (distEnd.containsKey(next)) {
                                return dist + 1 + distEnd.get(next);
                            }
                            if (!distStart.containsKey(next)) {
                                distStart.put(next, dist + 1);
                                qStart.add(next);
                            }
                        }
                    }
                }
            } else {
                int size = qEnd.size();
                for (int i = 0; i < size; i++) {
                    String current = qEnd.poll();
                    int dist = distEnd.get(current);
                    int xPos = current.indexOf('x');
                    int r = xPos / 3, c = xPos % 3;

                    for (int j = 0; j < 4; j++) {
                        int nr = r + dr[j];
                        int nc = c + dc[j];
                        if (nr >= 0 && nr < 3 && nc >= 0 && nc < 3) {
                            String next = swap(current, xPos, nr * 3 + nc);
                            if (distStart.containsKey(next)) {
                                return dist + 1 + distStart.get(next);
                            }
                            if (!distEnd.containsKey(next)) {
                                distEnd.put(next, dist + 1);
                                qEnd.add(next);
                            }
                        }
                    }
                }
            }
        }
        return -1; // Not found
    }

    private static String swap(String s, int i, int j) {
        char[] arr = s.toCharArray();
        char temp = arr[i];
        arr[i] = arr[j];
        arr[j] = temp;
        return new String(arr);
    }
}
import sys
from collections import deque

def bfs(start_state, end_state):
    if start_state == end_state:
        return 0

    q_start = deque([(start_state, 0)])
    dist_start = {start_state: 0}
    q_end = deque([(end_state, 0)])
    dist_end = {end_state: 0}

    dr = [-1, 1, 0, 0]
    dc = [0, 0, -1, 1]

    while q_start and q_end:
        if len(q_start) <= len(q_end):
            current, dist = q_start.popleft()
            
            x_pos = current.find('x')
            r, c = x_pos // 3, x_pos % 3

            for i in range(4):
                nr, nc = r + dr[i], c + dc[i]
                if 0 <= nr < 3 and 0 <= nc < 3:
                    new_pos = nr * 3 + nc
                    next_state_list = list(current)
                    next_state_list[x_pos], next_state_list[new_pos] = next_state_list[new_pos], next_state_list[x_pos]
                    next_state = "".join(next_state_list)

                    if next_state in dist_end:
                        return dist + 1 + dist_end[next_state]
                    if next_state not in dist_start:
                        dist_start[next_state] = dist + 1
                        q_start.append((next_state, dist + 1))
        else:
            current, dist = q_end.popleft()
            
            x_pos = current.find('x')
            r, c = x_pos // 3, x_pos % 3

            for i in range(4):
                nr, nc = r + dr[i], c + dc[i]
                if 0 <= nr < 3 and 0 <= nc < 3:
                    new_pos = nr * 3 + nc
                    next_state_list = list(current)
                    next_state_list[x_pos], next_state_list[new_pos] = next_state_list[new_pos], next_state_list[x_pos]
                    next_state = "".join(next_state_list)

                    if next_state in dist_start:
                        return dist + 1 + dist_start[next_state]
                    if next_state not in dist_end:
                        dist_end[next_state] = dist + 1
                        q_end.append((next_state, dist + 1))
    return -1


def solve():
    input_lines = sys.stdin.read().split()
    start_state = "".join(input_lines)
    end_state = "12345678x"
    
    print(bfs(start_state, end_state))

solve()

算法及复杂度

  • 算法:双向广度优先搜索 (Bi-directional BFS)
  • 时间复杂度,其中 是分支因子(每个状态平均约有2-3个后继状态), 是解的深度(最短路径长度)。八数码问题的可达状态总数为 个,双向BFS需要探索的状态数远小于此。
  • 空间复杂度,用于存储两个方向的队列和哈希表。