题目链接
题目描述
给出两个正整数 和
。每次操作可以选择其中一个数字,然后将其替换为
的几何平均数或平方平均数。问最少经过几次替换,可以使得这两个数相等。
- 几何平均数:
(上取整)
- 平方平均数:
(下取整)
解题思路
这是一个典型的最短路径问题,可以用广度优先搜索 (BFS) 来解决。
1. 问题建模
- 状态(节点): 每一个可能的数对
都是图中的一个节点。
- 操作(边): 从一个状态
到另一个状态的四种变换(对
或
应用两种平均数之一)就是图中的有向边。
- 目标: 找到从初始状态
到任意目标状态
的最短路径。由于每次操作计为一步,所有边的权重都是 1。
2. BFS 算法
BFS 是求解无权图最短路径的标准算法。
-
队列: 我们需要一个队列来存放待处理的状态。队列中的每个元素应包含三个信息:
(当前a, 当前b, 到达此状态的步数)
。 -
Visited 集合: 为了防止重复搜索和陷入死循环,我们需要一个集合来记录所有已经访问过的状态。
- 一个重要的细节是,状态
和
是等价的,因为它们能产生的后续状态完全相同。为了避免冗余,我们在
visited
集合中只存储它们的规范形式,例如(min(a,b), max(a,b))
。
- 一个重要的细节是,状态
-
算法流程:
a. 如果输入的
和
初始就相等,则答案为 0。
b. 创建一个队列,将初始状态
(a, b, 0)
入队。c. 创建一个
visited
集合,并将(min(a,b), max(a,b))
加入。d. 当队列不为空时,循环执行:
i. 出队一个状态
(curr_a, curr_b, dist)
。 ii. 计算四种可能的新状态:-
geom_mean = ceil(sqrt(curr_a * curr_b))
-
quad_mean = floor(sqrt((curr_a^2 + curr_b^2) / 2))
-
四个新状态分别是:
(geom_mean, curr_b)
,(curr_a, geom_mean)
,(quad_mean, curr_b)
,(curr_a, quad_mean)
。
iii. 对每个新状态
(next_a, next_b)
:-
检查是否到达终点: 如果
next_a == next_b
,则找到了最短路径,返回dist + 1
。 -
检查是否已访问: 将新状态转换为规范形式
(min(next_a, next_b), max(next_a, next_b))
,如果在visited
集合中已存在,则跳过。 -
入队: 否则,将新状态的规范形式加入
visited
集合,并将(next_a, next_b, dist + 1)
入队。
-
代码
#include <iostream>
#include <queue>
#include <vector>
#include <cmath>
#include <unordered_set>
#include <algorithm>
#include <utility>
#include <tuple>
using namespace std;
// Custom hash for std::pair
struct pair_hash {
template <class T1, class T2>
std::size_t operator () (const std::pair<T1,T2> &p) const {
auto h1 = std::hash<T1>{}(p.first);
auto h2 = std::hash<T2>{}(p.second);
// A simple way to combine hashes
return h1 ^ (h2 << 1);
}
};
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
long long a, b;
cin >> a >> b;
if (a == b) {
cout << 0 << endl;
return 0;
}
queue<tuple<long long, long long, int>> q;
q.emplace(a, b, 0);
unordered_set<pair<long long, long long>, pair_hash> visited;
visited.insert({min(a, b), max(a, b)});
while (!q.empty()) {
auto [cur_a, cur_b, dist] = q.front();
q.pop();
long long next_vals[2];
next_vals[0] = ceil(sqrt((double)cur_a * cur_b));
next_vals[1] = floor(sqrt(((double)cur_a * cur_a + (double)cur_b * cur_b) / 2.0));
for (long long val : next_vals) {
// Option 1: Replace a
long long next_a1 = val;
long long next_b1 = cur_b;
if (next_a1 == next_b1) {
cout << dist + 1 << endl;
return 0;
}
pair<long long, long long> p1 = {min(next_a1, next_b1), max(next_a1, next_b1)};
if (visited.find(p1) == visited.end()) {
visited.insert(p1);
q.emplace(next_a1, next_b1, dist + 1);
}
// Option 2: Replace b
long long next_a2 = cur_a;
long long next_b2 = val;
if (next_a2 == next_b2) {
cout << dist + 1 << endl;
return 0;
}
pair<long long, long long> p2 = {min(next_a2, next_b2), max(next_a2, next_b2)};
if (visited.find(p2) == visited.end()) {
visited.insert(p2);
q.emplace(next_a2, next_b2, dist + 1);
}
}
}
return 0;
}
import java.util.*;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
long a = sc.nextLong();
long b = sc.nextLong();
if (a == b) {
System.out.println(0);
return;
}
Queue<long[]> queue = new LinkedList<>();
queue.offer(new long[]{a, b, 0});
Set<String> visited = new HashSet<>();
visited.add(Math.min(a, b) + "," + Math.max(a, b));
while (!queue.isEmpty()) {
long[] current = queue.poll();
long curA = current[0];
long curB = current[1];
long dist = current[2];
long[] nextVals = new long[2];
nextVals[0] = (long) Math.ceil(Math.sqrt((double) curA * curB));
nextVals[1] = (long) Math.floor(Math.sqrt(((double) curA * curA + (double) curB * curB) / 2.0));
for (long val : nextVals) {
// Option 1: Replace a
long nextA1 = val;
long nextB1 = curB;
if (nextA1 == nextB1) {
System.out.println(dist + 1);
return;
}
String p1 = Math.min(nextA1, nextB1) + "," + Math.max(nextA1, nextB1);
if (!visited.contains(p1)) {
visited.add(p1);
queue.offer(new long[]{nextA1, nextB1, dist + 1});
}
// Option 2: Replace b
long nextA2 = curA;
long nextB2 = val;
if (nextA2 == nextB2) {
System.out.println(dist + 1);
return;
}
String p2 = Math.min(nextA2, nextB2) + "," + Math.max(nextA2, nextB2);
if (!visited.contains(p2)) {
visited.add(p2);
queue.offer(new long[]{nextA2, nextB2, dist + 1});
}
}
}
}
}
import sys
import math
from collections import deque
def solve():
try:
a, b = map(int, sys.stdin.readline().split())
except (IOError, ValueError):
return
if a == b:
print(0)
return
queue = deque([(a, b, 0)])
visited = {(min(a, b), max(a, b))}
while queue:
cur_a, cur_b, dist = queue.popleft()
geom_mean = math.ceil(math.sqrt(cur_a * cur_b))
quad_mean = math.floor(math.sqrt((cur_a**2 + cur_b**2) / 2))
next_vals = [geom_mean, quad_mean]
for val in next_vals:
# Option 1: Replace a
next_a1, next_b1 = val, cur_b
if next_a1 == next_b1:
print(dist + 1)
return
p1 = tuple(sorted((next_a1, next_b1)))
if p1 not in visited:
visited.add(p1)
queue.append((next_a1, next_b1, dist + 1))
# Option 2: Replace b
next_a2, next_b2 = cur_a, val
if next_a2 == next_b2:
print(dist + 1)
return
p2 = tuple(sorted((next_a2, next_b2)))
if p2 not in visited:
visited.add(p2)
queue.append((next_a2, next_b2, dist + 1))
solve()
算法及复杂度
-
算法:广度优先搜索 (BFS)
-
时间复杂度:
,其中
是从初始状态可达的状态(数对)总数。由于每个状态最多产生 4 个新状态,总的边数与顶点数呈线性关系。状态空间的大小取决于数值收敛的速度,难以精确给出,但对于题目数据范围是可接受的。
-
空间复杂度:
,主要用于存储队列和
visited
集合。