公因数排序

[题目链接](https://www.nowcoder.com/practice/7763486cf5b34e738b9b2a1095fbe4af)

思路

题意很直白:如果两个数有大于 1 的公因数,就可以交换它们的位置。问能不能通过若干次这种交换把数组排成升序。

直觉上,"能交换"这件事是可以传递的。比如 能交换, 能交换,那 最终也能到 的位置上——因为你可以通过多步交换来中转。这不就是并查集的经典场景吗?

关键观察:用质因数建图

两个数有公因数,等价于它们共享至少一个质因子。所以我们可以对每个数做质因数分解,然后把这个数和它的每个质因子在并查集里合并。这样一来,所有共享某个质因子的数就自动连通了。

举个例子:6 的质因子是 2 和 3,4 的质因子是 2。因为 6 和 4 都和 2 合并了,所以 6 和 4 在同一个连通分量里,它们可以交换。

判定是否可排序

建好并查集后,把原数组排个序。对每个位置 ,如果原数组的 和排序后的 不同,就检查它们是否在同一个连通分量里。如果某个位置不满足,说明那个元素没法通过交换到达正确位置,答案就是 No。

特判

值为 0 或 1 的元素没有大于 1 的因子,不能和任何其他数交换。如果它本身不在正确位置上,直接判 No。

质因数分解的优化

对于 Python 这类较慢的语言,直接试除分解可能超时。可以预处理一个最小质因子(SPF)筛,之后每个数的分解只需不断除以最小质因子即可,速度快很多。

代码

#include <bits/stdc++.h>
using namespace std;

int parent[1000001];
int rnk[1000001];

int find(int x) {
    while (parent[x] != x) {
        parent[x] = parent[parent[x]];
        x = parent[x];
    }
    return x;
}

void unite(int a, int b) {
    a = find(a); b = find(b);
    if (a == b) return;
    if (rnk[a] < rnk[b]) swap(a, b);
    parent[b] = a;
    if (rnk[a] == rnk[b]) rnk[a]++;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int T;
    cin >> T;
    while (T--) {
        int n;
        cin >> n;
        vector<int> a(n);
        int maxVal = 0;
        for (int i = 0; i < n; i++) {
            cin >> a[i];
            maxVal = max(maxVal, a[i]);
        }

        for (int i = 0; i <= maxVal; i++) {
            parent[i] = i;
            rnk[i] = 0;
        }

        for (int i = 0; i < n; i++) {
            int x = a[i];
            for (int f = 2; f * f <= x; f++) {
                if (x % f == 0) {
                    unite(a[i], f);
                    while (x % f == 0) x /= f;
                }
            }
            if (x > 1) unite(a[i], x);
        }

        vector<int> sorted_a(a);
        sort(sorted_a.begin(), sorted_a.end());

        bool ok = true;
        for (int i = 0; i < n; i++) {
            if (a[i] != sorted_a[i]) {
                if (a[i] <= 1 || sorted_a[i] <= 1) {
                    ok = false;
                    break;
                }
                if (find(a[i]) != find(sorted_a[i])) {
                    ok = false;
                    break;
                }
            }
        }

        cout << (ok ? "Yes" : "No") << "\n";
    }
    return 0;
}
import java.util.*;
import java.io.*;

public class Main {
    static int[] parent = new int[1000001];
    static int[] rank = new int[1000001];

    static int find(int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        return x;
    }

    static void unite(int a, int b) {
        a = find(a); b = find(b);
        if (a == b) return;
        if (rank[a] < rank[b]) { int t = a; a = b; b = t; }
        parent[b] = a;
        if (rank[a] == rank[b]) rank[a]++;
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringBuilder sb = new StringBuilder();

        int T = Integer.parseInt(br.readLine().trim());
        while (T-- > 0) {
            int n = Integer.parseInt(br.readLine().trim());
            StringTokenizer st = new StringTokenizer(br.readLine().trim());
            int[] a = new int[n];
            int maxVal = 0;
            for (int i = 0; i < n; i++) {
                a[i] = Integer.parseInt(st.nextToken());
                maxVal = Math.max(maxVal, a[i]);
            }

            for (int i = 0; i <= maxVal; i++) {
                parent[i] = i;
                rank[i] = 0;
            }

            for (int i = 0; i < n; i++) {
                int x = a[i];
                for (int f = 2; f * f <= x; f++) {
                    if (x % f == 0) {
                        unite(a[i], f);
                        while (x % f == 0) x /= f;
                    }
                }
                if (x > 1) unite(a[i], x);
            }

            int[] sorted = a.clone();
            Arrays.sort(sorted);

            boolean ok = true;
            for (int i = 0; i < n; i++) {
                if (a[i] != sorted[i]) {
                    if (a[i] <= 1 || sorted[i] <= 1) {
                        ok = false;
                        break;
                    }
                    if (find(a[i]) != find(sorted[i])) {
                        ok = false;
                        break;
                    }
                }
            }

            sb.append(ok ? "Yes" : "No").append("\n");
        }
        System.out.print(sb);
    }
}
import sys

def main():
    data = sys.stdin.buffer.read().split()
    pos = 0

    MAX = 1000001
    spf = list(range(MAX))
    i = 2
    while i * i < MAX:
        if spf[i] == i:
            for j in range(i * i, MAX, i):
                if spf[j] == j:
                    spf[j] = i
        i += 1

    T = int(data[pos]); pos += 1
    out = []
    for _ in range(T):
        n = int(data[pos]); pos += 1
        a = [int(data[pos + i]) for i in range(n)]
        pos += n

        parent = {}

        def find(x):
            r = x
            while True:
                p = parent.get(r, r)
                if p == r:
                    break
                r = p
            while x != r:
                nx = parent.get(x, x)
                parent[x] = r
                x = nx
            return r

        def unite(a, b):
            a, b = find(a), find(b)
            if a != b:
                parent[a] = b

        for v in a:
            if v <= 1:
                continue
            x = v
            while x > 1:
                p = spf[x]
                unite(v, p)
                while x > 1 and spf[x] == p:
                    x //= p

        sorted_a = sorted(a)
        ok = True
        for i in range(n):
            if a[i] != sorted_a[i]:
                if a[i] <= 1 or sorted_a[i] <= 1:
                    ok = False
                    break
                if find(a[i]) != find(sorted_a[i]):
                    ok = False
                    break

        out.append("Yes" if ok else "No")

    sys.stdout.write('\n'.join(out) + '\n')

main()
const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });
const lines = [];

rl.on('line', line => lines.push(line.trim()));
rl.on('close', () => {
    let idx = 0;
    const parent = new Int32Array(1000001);
    const rnk = new Int32Array(1000001);

    function find(x) {
        while (parent[x] !== x) {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        return x;
    }

    function unite(a, b) {
        a = find(a); b = find(b);
        if (a === b) return;
        if (rnk[a] < rnk[b]) { let t = a; a = b; b = t; }
        parent[b] = a;
        if (rnk[a] === rnk[b]) rnk[a]++;
    }

    const res = [];
    let T = parseInt(lines[idx++]);
    while (T--) {
        const n = parseInt(lines[idx++]);
        const a = lines[idx++].split(' ').map(Number);
        let maxVal = 0;
        for (let i = 0; i < n; i++) if (a[i] > maxVal) maxVal = a[i];

        for (let i = 0; i <= maxVal; i++) { parent[i] = i; rnk[i] = 0; }

        for (let i = 0; i < n; i++) {
            let x = a[i];
            for (let f = 2; f * f <= x; f++) {
                if (x % f === 0) {
                    unite(a[i], f);
                    while (x % f === 0) x = Math.floor(x / f);
                }
            }
            if (x > 1) unite(a[i], x);
        }

        const sorted = a.slice().sort((x, y) => x - y);
        let ok = true;
        for (let i = 0; i < n; i++) {
            if (a[i] !== sorted[i]) {
                if (a[i] <= 1 || sorted[i] <= 1) { ok = false; break; }
                if (find(a[i]) !== find(sorted[i])) { ok = false; break; }
            }
        }
        res.push(ok ? "Yes" : "No");
    }
    console.log(res.join('\n'));
});

复杂度分析

  • 时间复杂度,其中 是数组中的最大值。对每个数做试除分解需要 ,排序需要 ,并查集操作近似 。Python 版本使用 SPF 筛,预处理 ,单次分解
  • 空间复杂度,用于并查集数组(或 Python 中的 SPF 筛表)。