题目链接

平均值

题目描述

给出两个正整数 。每次操作可以选择其中一个数字,然后将其替换为 的几何平均数或平方平均数。问最少经过几次替换,可以使得这两个数相等。

  • 几何平均数: (上取整)
  • 平方平均数: (下取整)

解题思路

这是一个典型的最短路径问题,可以用广度优先搜索 (BFS) 来解决。

1. 问题建模

  • 状态(节点): 每一个可能的数对 都是图中的一个节点。
  • 操作(边): 从一个状态 到另一个状态的四种变换(对 应用两种平均数之一)就是图中的有向边。
  • 目标: 找到从初始状态 到任意目标状态 的最短路径。由于每次操作计为一步,所有边的权重都是 1。

2. BFS 算法

BFS 是求解无权图最短路径的标准算法。

  1. 队列: 我们需要一个队列来存放待处理的状态。队列中的每个元素应包含三个信息:(当前a, 当前b, 到达此状态的步数)

  2. Visited 集合: 为了防止重复搜索和陷入死循环,我们需要一个集合来记录所有已经访问过的状态。

    • 一个重要的细节是,状态 是等价的,因为它们能产生的后续状态完全相同。为了避免冗余,我们在 visited 集合中只存储它们的规范形式,例如 (min(a,b), max(a,b))
  3. 算法流程:

    a. 如果输入的 初始就相等,则答案为 0。

    b. 创建一个队列,将初始状态 (a, b, 0) 入队。

    c. 创建一个 visited 集合,并将 (min(a,b), max(a,b)) 加入。

    d. 当队列不为空时,循环执行:

    i. 出队一个状态 (curr_a, curr_b, dist)。 ii. 计算四种可能的新状态:

    • geom_mean = ceil(sqrt(curr_a * curr_b))

    • quad_mean = floor(sqrt((curr_a^2 + curr_b^2) / 2))

    • 四个新状态分别是: (geom_mean, curr_b), (curr_a, geom_mean), (quad_mean, curr_b), (curr_a, quad_mean)

    iii. 对每个新状态 (next_a, next_b)

    • 检查是否到达终点: 如果 next_a == next_b,则找到了最短路径,返回 dist + 1

    • 检查是否已访问: 将新状态转换为规范形式 (min(next_a, next_b), max(next_a, next_b)),如果在 visited 集合中已存在,则跳过。

    • 入队: 否则,将新状态的规范形式加入 visited 集合,并将 (next_a, next_b, dist + 1) 入队。

代码

#include <iostream>
#include <queue>
#include <vector>
#include <cmath>
#include <unordered_set>
#include <algorithm>
#include <utility>
#include <tuple>

using namespace std;

// Custom hash for std::pair
struct pair_hash {
    template <class T1, class T2>
    std::size_t operator () (const std::pair<T1,T2> &p) const {
        auto h1 = std::hash<T1>{}(p.first);
        auto h2 = std::hash<T2>{}(p.second);
        // A simple way to combine hashes
        return h1 ^ (h2 << 1);
    }
};

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    long long a, b;
    cin >> a >> b;

    if (a == b) {
        cout << 0 << endl;
        return 0;
    }

    queue<tuple<long long, long long, int>> q;
    q.emplace(a, b, 0);

    unordered_set<pair<long long, long long>, pair_hash> visited;
    visited.insert({min(a, b), max(a, b)});

    while (!q.empty()) {
        auto [cur_a, cur_b, dist] = q.front();
        q.pop();

        long long next_vals[2];
        next_vals[0] = ceil(sqrt((double)cur_a * cur_b));
        next_vals[1] = floor(sqrt(((double)cur_a * cur_a + (double)cur_b * cur_b) / 2.0));

        for (long long val : next_vals) {
            // Option 1: Replace a
            long long next_a1 = val;
            long long next_b1 = cur_b;
            if (next_a1 == next_b1) {
                cout << dist + 1 << endl;
                return 0;
            }
            pair<long long, long long> p1 = {min(next_a1, next_b1), max(next_a1, next_b1)};
            if (visited.find(p1) == visited.end()) {
                visited.insert(p1);
                q.emplace(next_a1, next_b1, dist + 1);
            }
            
            // Option 2: Replace b
            long long next_a2 = cur_a;
            long long next_b2 = val;
            if (next_a2 == next_b2) {
                cout << dist + 1 << endl;
                return 0;
            }
            pair<long long, long long> p2 = {min(next_a2, next_b2), max(next_a2, next_b2)};
            if (visited.find(p2) == visited.end()) {
                visited.insert(p2);
                q.emplace(next_a2, next_b2, dist + 1);
            }
        }
    }

    return 0;
}

import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        long a = sc.nextLong();
        long b = sc.nextLong();

        if (a == b) {
            System.out.println(0);
            return;
        }

        Queue<long[]> queue = new LinkedList<>();
        queue.offer(new long[]{a, b, 0});

        Set<String> visited = new HashSet<>();
        visited.add(Math.min(a, b) + "," + Math.max(a, b));

        while (!queue.isEmpty()) {
            long[] current = queue.poll();
            long curA = current[0];
            long curB = current[1];
            long dist = current[2];

            long[] nextVals = new long[2];
            nextVals[0] = (long) Math.ceil(Math.sqrt((double) curA * curB));
            nextVals[1] = (long) Math.floor(Math.sqrt(((double) curA * curA + (double) curB * curB) / 2.0));

            for (long val : nextVals) {
                // Option 1: Replace a
                long nextA1 = val;
                long nextB1 = curB;
                if (nextA1 == nextB1) {
                    System.out.println(dist + 1);
                    return;
                }
                String p1 = Math.min(nextA1, nextB1) + "," + Math.max(nextA1, nextB1);
                if (!visited.contains(p1)) {
                    visited.add(p1);
                    queue.offer(new long[]{nextA1, nextB1, dist + 1});
                }

                // Option 2: Replace b
                long nextA2 = curA;
                long nextB2 = val;
                if (nextA2 == nextB2) {
                    System.out.println(dist + 1);
                    return;
                }
                String p2 = Math.min(nextA2, nextB2) + "," + Math.max(nextA2, nextB2);
                if (!visited.contains(p2)) {
                    visited.add(p2);
                    queue.offer(new long[]{nextA2, nextB2, dist + 1});
                }
            }
        }
    }
}

import sys
import math
from collections import deque

def solve():
    try:
        a, b = map(int, sys.stdin.readline().split())
    except (IOError, ValueError):
        return

    if a == b:
        print(0)
        return

    queue = deque([(a, b, 0)])
    visited = {(min(a, b), max(a, b))}

    while queue:
        cur_a, cur_b, dist = queue.popleft()

        geom_mean = math.ceil(math.sqrt(cur_a * cur_b))
        quad_mean = math.floor(math.sqrt((cur_a**2 + cur_b**2) / 2))

        next_vals = [geom_mean, quad_mean]

        for val in next_vals:
            # Option 1: Replace a
            next_a1, next_b1 = val, cur_b
            if next_a1 == next_b1:
                print(dist + 1)
                return
            
            p1 = tuple(sorted((next_a1, next_b1)))
            if p1 not in visited:
                visited.add(p1)
                queue.append((next_a1, next_b1, dist + 1))

            # Option 2: Replace b
            next_a2, next_b2 = cur_a, val
            if next_a2 == next_b2:
                print(dist + 1)
                return
            
            p2 = tuple(sorted((next_a2, next_b2)))
            if p2 not in visited:
                visited.add(p2)
                queue.append((next_a2, next_b2, dist + 1))

solve()

算法及复杂度

  • 算法:广度优先搜索 (BFS)

  • 时间复杂度,其中 是从初始状态可达的状态(数对)总数。由于每个状态最多产生 4 个新状态,总的边数与顶点数呈线性关系。状态空间的大小取决于数值收敛的速度,难以精确给出,但对于题目数据范围是可接受的。

  • 空间复杂度,主要用于存储队列和 visited 集合。