题目链接
题目描述
在一个 3x3 的方格中,有数字1到8和'x'(代表空格)共九个元素。每次操作可以将空格'x'与其上下左右相邻的数字进行交换。给定一个初始状态,你需要找出将其恢复为目标状态 12345678x
所需的最少交换次数。
解题思路
-
问题建模:状态图搜索
我们可以将八数碼的每一种可能的排列看作一个状态,每次交换操作看作是状态之间的转移。这样,整个问题就转化为了在一个巨大的状态图中,寻找从初始状态到目标状态的最短路径。
求解无权图的最短路径问题,最经典的算法就是广度优先搜索 (BFS)。
-
朴素 BFS
一个直接的想法是从初始状态开始进行 BFS。
- 状态表示:用一个长度为9的字符串来表示 3x3 的棋盘状态。
- 队列:队列中存储
(状态字符串, 到达该状态的步数)
。 - 判重:用一个哈希集合(
set
或map
)来记录已经访问过的状态,防止重复搜索和陷入死循环。 - 流程:从初始状态入队,每次取出队首状态,生成所有下一步可能的状态,如果新状态未被访问过,则将其入队并标记为已访问。直到找到目标状态为止。
-
优化:双向广度优先搜索 (Bi-directional BFS)
朴素BFS的搜索范围像一个不断扩大的圆,其面积(搜索的节点数)与半径(搜索深度
d
)呈指数关系增长(,
b
是分支因子)。双向BFS通过同时从起点和终点进行搜索来极大地优化这个过程。它像是两个相向而行的、不断扩大的圆。当两个圆相遇时,搜索就结束了。
- 效率:两个半径为
d/2
的小圆的面积之和远小于一个半径为d
的大圆的面积()。这使得搜索效率得到巨大提升。
- 实现:
- 需要两个队列,
q_start
和q_end
,分别用于从起点和终点开始的搜索。 - 需要两个哈希表,
dist_start
和dist_end
,用于记录从起点和终点到达某个状态的最短距离,并兼具判重功能。 - 将初始状态和目标状态分别加入各自的队列和哈希表。
- 在主循环中,为了保持两个搜索“圆”的大小相当,每次选择当前元素较少的那个队列进行扩展。
- 当从一个方向(如起点)扩展出一个新状态
next_state
时,检查这个状态是否在另一个方向的哈希表dist_end
中出现过。 - 如果出现过,说明两条搜索路径在此相遇了。总的最短路径长度就是
dist_start[current_state] + 1 + dist_end[next_state]
。搜索结束。 - 如果没有相遇,且
next_state
未被当前方向访问过,则将其加入当前方向的队列和哈希表。
- 需要两个队列,
- 效率:两个半径为
-
关于可解性
不是所有的八数码初始状态都能恢复到目标状态。一个状态是否可解,取决于其逆序对的数量。逆序对是指将 3x3 矩阵按行读成一维序列后,所有数字(不含空格)中,排在前面的数大于排在后面的数的对数。对于八数码问题,只有逆序对数量为偶数的状态才是可解的。本题的测试数据保证了所有初始状态都是可解的。
代码
#include <iostream>
#include <vector>
#include <string>
#include <queue>
#include <map>
#include <algorithm>
using namespace std;
string get_initial_state() {
string s = "";
char c;
for (int i = 0; i < 9; ++i) {
cin >> c;
if (c != ' ' && c != '\n') {
s += c;
}
}
return s;
}
int bfs(const string& start_state, const string& end_state) {
if (start_state == end_state) return 0;
queue<string> q_start, q_end;
map<string, int> dist_start, dist_end;
q_start.push(start_state);
dist_start[start_state] = 0;
q_end.push(end_state);
dist_end[end_state] = 0;
int dr[] = {-1, 1, 0, 0}; // Up, Down, Left, Right
int dc[] = {0, 0, -1, 1};
while (!q_start.empty() && !q_end.empty()) {
// 优先扩展较小的队列
if (q_start.size() <= q_end.size()) {
string current = q_start.front();
q_start.pop();
int dist = dist_start[current];
int x_pos = current.find('x');
int r = x_pos / 3;
int c = x_pos % 3;
for (int i = 0; i < 4; ++i) {
int nr = r + dr[i];
int nc = c + dc[i];
if (nr >= 0 && nr < 3 && nc >= 0 && nc < 3) {
string next = current;
swap(next[x_pos], next[nr * 3 + nc]);
if (dist_end.count(next)) {
return dist + 1 + dist_end[next];
}
if (!dist_start.count(next)) {
dist_start[next] = dist + 1;
q_start.push(next);
}
}
}
} else {
string current = q_end.front();
q_end.pop();
int dist = dist_end[current];
int x_pos = current.find('x');
int r = x_pos / 3;
int c = x_pos % 3;
for (int i = 0; i < 4; ++i) {
int nr = r + dr[i];
int nc = c + dc[i];
if (nr >= 0 && nr < 3 && nc >= 0 && nc < 3) {
string next = current;
swap(next[x_pos], next[nr * 3 + nc]);
if (dist_start.count(next)) {
return dist + 1 + dist_start[next];
}
if (!dist_end.count(next)) {
dist_end[next] = dist + 1;
q_end.push(next);
}
}
}
}
}
return -1; // Should not be reached for solvable puzzles
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
string start_state = "";
char val;
for(int i = 0; i < 9; ++i) {
cin >> val;
start_state += val;
}
string end_state = "12345678x";
cout << bfs(start_state, end_state) << endl;
return 0;
}
import java.util.*;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
StringBuilder sb = new StringBuilder();
for (int i = 0; i < 9; i++) {
sb.append(sc.next());
}
String startState = sb.toString();
String endState = "12345678x";
System.out.println(bfs(startState, endState));
}
private static int bfs(String start, String end) {
if (start.equals(end)) return 0;
Queue<String> qStart = new LinkedList<>();
Map<String, Integer> distStart = new HashMap<>();
Queue<String> qEnd = new LinkedList<>();
Map<String, Integer> distEnd = new HashMap<>();
qStart.add(start);
distStart.put(start, 0);
qEnd.add(end);
distEnd.put(end, 0);
int[] dr = {-1, 1, 0, 0};
int[] dc = {0, 0, -1, 1};
while (!qStart.isEmpty() && !qEnd.isEmpty()) {
if (qStart.size() <= qEnd.size()) {
int size = qStart.size();
for (int i = 0; i < size; i++) {
String current = qStart.poll();
int dist = distStart.get(current);
int xPos = current.indexOf('x');
int r = xPos / 3, c = xPos % 3;
for (int j = 0; j < 4; j++) {
int nr = r + dr[j];
int nc = c + dc[j];
if (nr >= 0 && nr < 3 && nc >= 0 && nc < 3) {
String next = swap(current, xPos, nr * 3 + nc);
if (distEnd.containsKey(next)) {
return dist + 1 + distEnd.get(next);
}
if (!distStart.containsKey(next)) {
distStart.put(next, dist + 1);
qStart.add(next);
}
}
}
}
} else {
int size = qEnd.size();
for (int i = 0; i < size; i++) {
String current = qEnd.poll();
int dist = distEnd.get(current);
int xPos = current.indexOf('x');
int r = xPos / 3, c = xPos % 3;
for (int j = 0; j < 4; j++) {
int nr = r + dr[j];
int nc = c + dc[j];
if (nr >= 0 && nr < 3 && nc >= 0 && nc < 3) {
String next = swap(current, xPos, nr * 3 + nc);
if (distStart.containsKey(next)) {
return dist + 1 + distStart.get(next);
}
if (!distEnd.containsKey(next)) {
distEnd.put(next, dist + 1);
qEnd.add(next);
}
}
}
}
}
}
return -1; // Not found
}
private static String swap(String s, int i, int j) {
char[] arr = s.toCharArray();
char temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
return new String(arr);
}
}
import sys
from collections import deque
def bfs(start_state, end_state):
if start_state == end_state:
return 0
q_start = deque([(start_state, 0)])
dist_start = {start_state: 0}
q_end = deque([(end_state, 0)])
dist_end = {end_state: 0}
dr = [-1, 1, 0, 0]
dc = [0, 0, -1, 1]
while q_start and q_end:
if len(q_start) <= len(q_end):
current, dist = q_start.popleft()
x_pos = current.find('x')
r, c = x_pos // 3, x_pos % 3
for i in range(4):
nr, nc = r + dr[i], c + dc[i]
if 0 <= nr < 3 and 0 <= nc < 3:
new_pos = nr * 3 + nc
next_state_list = list(current)
next_state_list[x_pos], next_state_list[new_pos] = next_state_list[new_pos], next_state_list[x_pos]
next_state = "".join(next_state_list)
if next_state in dist_end:
return dist + 1 + dist_end[next_state]
if next_state not in dist_start:
dist_start[next_state] = dist + 1
q_start.append((next_state, dist + 1))
else:
current, dist = q_end.popleft()
x_pos = current.find('x')
r, c = x_pos // 3, x_pos % 3
for i in range(4):
nr, nc = r + dr[i], c + dc[i]
if 0 <= nr < 3 and 0 <= nc < 3:
new_pos = nr * 3 + nc
next_state_list = list(current)
next_state_list[x_pos], next_state_list[new_pos] = next_state_list[new_pos], next_state_list[x_pos]
next_state = "".join(next_state_list)
if next_state in dist_start:
return dist + 1 + dist_start[next_state]
if next_state not in dist_end:
dist_end[next_state] = dist + 1
q_end.append((next_state, dist + 1))
return -1
def solve():
input_lines = sys.stdin.read().split()
start_state = "".join(input_lines)
end_state = "12345678x"
print(bfs(start_state, end_state))
solve()
算法及复杂度
- 算法:双向广度优先搜索 (Bi-directional BFS)
- 时间复杂度:
,其中
是分支因子(每个状态平均约有2-3个后继状态),
是解的深度(最短路径长度)。八数码问题的可达状态总数为
个,双向BFS需要探索的状态数远小于此。
- 空间复杂度:
,用于存储两个方向的队列和哈希表。