平均值

[题目链接](https://www.nowcoder.com/practice/c51ec17cd3034374be2b95d73ef46a1a)

思路

给定两个正整数 ,每次操作可以选择其中一个数,将其替换为:

  • 几何平均值(上取整)
  • 平方平均值(下取整)

求使两个数相等的最少操作次数。

关键观察

两种平均值都在 之间(或非常接近),因此每次操作都在缩小两个数的差距。这意味着:

  1. 状态空间很小——从任意一对 出发,经过很少的步数两个数就会收敛到相等。
  2. 由于 本质相同,我们可以将状态标准化为 ,减少一半搜索空间。

BFS 求最短路

将每个状态 (其中 )看作图中的一个节点。从初始状态出发,每步有 4 种转移:

  • 用几何平均值替换
  • 用平方平均值替换

用 BFS 搜索即可得到最短操作次数。由于值快速收敛,实际访问的状态非常少。

整数精度处理

浮点 sqrt 可能有精度误差,需要对结果做微调:

  • 几何平均值(上取整):先算 ,然后向下/向上调整直到 是满足 的最小整数。
  • 平方平均值(下取整):先算 ,然后调整直到

代码

[sol-C++]

#include <bits/stdc++.h>
using namespace std;

int main(){
    long long a, b;
    scanf("%lld%lld", &a, &b);
    if(a == b){
        printf("0\n");
        return 0;
    }
    map<pair<long long,long long>, int> dist;
    queue<pair<long long,long long>> q;
    auto norm = [](long long x, long long y) -> pair<long long,long long> {
        return {min(x,y), max(x,y)};
    };
    auto start = norm(a, b);
    dist[start] = 0;
    q.push(start);
    while(!q.empty()){
        auto [x, y] = q.front(); q.pop();
        int d = dist[{x,y}];
        // 几何平均值 ceil(sqrt(x*y))
        long long gm = (long long)ceil(sqrt((double)x * y));
        while(gm * gm < x * y) gm++;
        while((gm-1)*(gm-1) >= x * y && gm > 0) gm--;
        // 平方平均值 floor(sqrt((x^2+y^2)/2))
        long long qm = (long long)sqrt(((double)x*x + (double)y*y) / 2.0);
        while((2*qm*qm) > (x*x + y*y)) qm--;
        while((2*(qm+1)*(qm+1)) <= (x*x + y*y)) qm++;

        long long nexts[] = {gm, qm};
        for(long long nv : nexts){
            auto s1 = norm(nv, y);
            if(dist.find(s1) == dist.end()){
                dist[s1] = d + 1;
                if(s1.first == s1.second){ printf("%d\n", d+1); return 0; }
                q.push(s1);
            }
            auto s2 = norm(x, nv);
            if(dist.find(s2) == dist.end()){
                dist[s2] = d + 1;
                if(s2.first == s2.second){ printf("%d\n", d+1); return 0; }
                q.push(s2);
            }
        }
    }
    printf("-1\n");
    return 0;
}

[sol-Java]

import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        long a = sc.nextLong(), b = sc.nextLong();
        if (a == b) { System.out.println(0); return; }

        Map<String, Integer> dist = new HashMap<>();
        Queue<long[]> queue = new LinkedList<>();
        long x0 = Math.min(a, b), y0 = Math.max(a, b);
        String startKey = x0 + "," + y0;
        dist.put(startKey, 0);
        queue.add(new long[]{x0, y0});

        while (!queue.isEmpty()) {
            long[] cur = queue.poll();
            long x = cur[0], y = cur[1];
            int d = dist.get(x + "," + y);

            long gm = (long) Math.ceil(Math.sqrt((double) x * y));
            while (gm * gm < x * y) gm++;
            while (gm > 0 && (gm - 1) * (gm - 1) >= x * y) gm--;

            long s = x * x + y * y;
            long qm = (long) Math.sqrt((double) s / 2.0);
            while (2 * qm * qm > s) qm--;
            while (2 * (qm + 1) * (qm + 1) <= s) qm++;

            long[] nexts = {gm, qm};
            for (long nv : nexts) {
                long nx1 = Math.min(nv, y), ny1 = Math.max(nv, y);
                String k1 = nx1 + "," + ny1;
                if (!dist.containsKey(k1)) {
                    dist.put(k1, d + 1);
                    if (nx1 == ny1) { System.out.println(d + 1); return; }
                    queue.add(new long[]{nx1, ny1});
                }
                long nx2 = Math.min(x, nv), ny2 = Math.max(x, nv);
                String k2 = nx2 + "," + ny2;
                if (!dist.containsKey(k2)) {
                    dist.put(k2, d + 1);
                    if (nx2 == ny2) { System.out.println(d + 1); return; }
                    queue.add(new long[]{nx2, ny2});
                }
            }
        }
        System.out.println(-1);
    }
}

[sol-Python3]

import math
from collections import deque

def solve():
    a, b = map(int, input().split())
    if a == b:
        print(0)
        return

    def norm(x, y):
        return (min(x, y), max(x, y))

    start = norm(a, b)
    dist = {start: 0}
    q = deque([start])

    while q:
        x, y = q.popleft()
        d = dist[(x, y)]

        # 几何平均值 ceil(sqrt(x*y))
        prod = x * y
        gm = math.isqrt(prod)
        if gm * gm < prod:
            gm += 1

        # 平方平均值 floor(sqrt((x^2+y^2)/2))
        s = x * x + y * y
        qm = math.isqrt(s // 2)
        while 2 * qm * qm > s:
            qm -= 1
        while 2 * (qm + 1) * (qm + 1) <= s:
            qm += 1

        for nv in (gm, qm):
            s1 = norm(nv, y)
            if s1 not in dist:
                dist[s1] = d + 1
                if s1[0] == s1[1]:
                    print(d + 1)
                    return
                q.append(s1)
            s2 = norm(x, nv)
            if s2 not in dist:
                dist[s2] = d + 1
                if s2[0] == s2[1]:
                    print(d + 1)
                    return
                q.append(s2)

    print(-1)

solve()

[sol-JavaScript]

const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });

rl.on('line', (line) => {
    const parts = line.trim().split(/\s+/);
    const a = BigInt(parts[0]), b = BigInt(parts[1]);
    if (a === b) { console.log(0); rl.close(); return; }

    function norm(x, y) { return x < y ? [x, y] : [y, x]; }
    function key(x, y) { return x.toString() + ',' + y.toString(); }
    function isqrt(n) {
        if (n <= 0n) return 0n;
        let x = BigInt(Math.floor(Math.sqrt(Number(n))));
        while (x * x > n) x--;
        while ((x + 1n) * (x + 1n) <= n) x++;
        return x;
    }

    const [x0, y0] = norm(a, b);
    const dist = new Map();
    dist.set(key(x0, y0), 0);
    const queue = [[x0, y0]];
    let head = 0;

    while (head < queue.length) {
        const [x, y] = queue[head++];
        const d = dist.get(key(x, y));

        const prod = x * y;
        let gm = isqrt(prod);
        if (gm * gm < prod) gm++;

        const s = x * x + y * y;
        let qm = isqrt(s / 2n);
        while (2n * qm * qm > s) qm--;
        while (2n * (qm + 1n) * (qm + 1n) <= s) qm++;

        for (const nv of [gm, qm]) {
            const [nx1, ny1] = norm(nv, y);
            const k1 = key(nx1, ny1);
            if (!dist.has(k1)) {
                dist.set(k1, d + 1);
                if (nx1 === ny1) { console.log(d + 1); rl.close(); return; }
                queue.push([nx1, ny1]);
            }
            const [nx2, ny2] = norm(x, nv);
            const k2 = key(nx2, ny2);
            if (!dist.has(k2)) {
                dist.set(k2, d + 1);
                if (nx2 === ny2) { console.log(d + 1); rl.close(); return; }
                queue.push([nx2, ny2]);
            }
        }
    }
    console.log(-1);
    rl.close();
});

复杂度分析

  • 时间复杂度,其中 是 BFS 访问的状态数。由于几何平均和平方平均都使两个数快速收敛, 非常小(实测最多几十个状态),map 查找带 因子。
  • 空间复杂度,存储 BFS 队列和距离表。