题目链接

PEEK56 数列 k 重排

题目描述

给定一个长度为 的序列 和一个长度为 的置换序列 。我们定义一次操作为:根据 生成一个新的序列 ,其中 。然后用序列 替换原来的序列

你需要计算出,对初始序列 执行 次这样的操作后,最终得到的序列是什么。

解题思路

由于操作次数 的值可能非常大(高达 ),直接模拟是不可行的。

该问题的核心是理解操作的本质。k 次操作后,最终序列在 i 位置的值,来自于初始序列的哪个位置?

  • 1 次操作后: A_1[i] = A_0[X[i]]
  • 2 次操作后: A_2[i] = A_1[X[i]]。将 i 替换为 X[i] 代入上式,得到 A_1[X[i]] = A_0[X[X[i]]]。所以 A_2[i] = A_0[X[X[i]]]
  • k 次操作后: A_k[i] = A_0[X^k(i)],其中 X^k(i) 表示从索引 i 开始,连续应用 X 映射 k 次。

所以,我们的任务就是高效地计算出映射 Xk 次幂 X^k。这正是倍增(Binary Lifting) 算法的应用场景。

  1. 预计算: 我们创建一个二维数组 jump[p][i],存储从索引 i 开始,连续应用 X 映射 次后到达的索引。

    • 基础状态 (): jump[0][i] = X[i]
    • 递推关系: jump[p][i] = jump[p-1][jump[p-1][i]]。 预计算的时间复杂度为
  2. 计算最终映射: 对于每个目标位置 i,我们利用 jump 表计算出 X^k(i)。我们将 k 二进制分解,如果 k 的第 p 位为 1,就进行一次 的跳转。此步骤复杂度为

  3. 构造结果: 设 final_source_map[i] = X^k(i)。那么最终结果 result[i] = A_initial[final_source_map[i]]

性能优化说明: 此算法的时间复杂度为 ,在 都很大时,标准 I/O 可能会成为性能瓶颈导致超时。因此,对于 Java 和 Python 的实现,我们采用更快的 I/O 方式(BufferedReadersys.stdin.readline)来确保通过。

代码

#include <iostream>
#include <vector>
#include <numeric>

using namespace std;

const int LOGK = 60; // 2^60 > 10^18

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    long long k;
    cin >> n >> k;

    vector<vector<int>> jump(LOGK, vector<int>(n));
    for (int i = 0; i < n; ++i) {
        cin >> jump[0][i];
        jump[0][i]--; // 转换为0-based索引
    }

    vector<int> a(n);
    for (int i = 0; i < n; ++i) {
        cin >> a[i];
    }

    for (int p = 1; p < LOGK; ++p) {
        for (int i = 0; i < n; ++i) {
            jump[p][i] = jump[p - 1][jump[p - 1][i]];
        }
    }

    vector<int> final_source_map(n);
    for (int i = 0; i < n; ++i) {
        int current_source = i;
        for (int p = LOGK - 1; p >= 0; --p) {
            if ((k >> p) & 1) {
                current_source = jump[p][current_source];
            }
        }
        final_source_map[i] = current_source;
    }

    vector<int> result(n);
    // 注意这里,最终的 i 位置的值,来自 a 数组的 final_source_map[i] 位置
    for (int i = 0; i < n; ++i) {
        result[i] = a[final_source_map[i]];
    }

    for (int i = 0; i < n; ++i) {
        cout << result[i] << (i == n - 1 ? "" : " ");
    }
    cout << "\n";

    return 0;
}
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.io.IOException;
import java.util.StringTokenizer;

public class Main {
    private static final int LOGK = 60;

    public static void main(String[] args) throws IOException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
        PrintWriter writer = new PrintWriter(System.out);
        StringTokenizer st = new StringTokenizer(reader.readLine());
        
        int n = Integer.parseInt(st.nextToken());
        long k = Long.parseLong(st.nextToken());

        int[][] jump = new int[LOGK][n];
        st = new StringTokenizer(reader.readLine());
        for (int i = 0; i < n; i++) {
            jump[0][i] = Integer.parseInt(st.nextToken()) - 1; // 转换为0-based索引
        }

        int[] a = new int[n];
        st = new StringTokenizer(reader.readLine());
        for (int i = 0; i < n; i++) {
            a[i] = Integer.parseInt(st.nextToken());
        }

        for (int p = 1; p < LOGK; p++) {
            for (int i = 0; i < n; i++) {
                jump[p][i] = jump[p - 1][jump[p - 1][i]];
            }
        }

        int[] finalSourceMap = new int[n];
        for (int i = 0; i < n; i++) {
            int currentSource = i;
            for (int p = LOGK - 1; p >= 0; p--) {
                if (((k >> p) & 1) == 1) {
                    currentSource = jump[p][currentSource];
                }
            }
            finalSourceMap[i] = currentSource;
        }

        int[] result = new int[n];
        for (int i = 0; i < n; i++) {
            result[i] = a[finalSourceMap[i]];
        }

        for (int i = 0; i < n; i++) {
            writer.print(result[i] + (i == n - 1 ? "" : " "));
        }
        writer.println();
        writer.flush();
    }
}
import sys

def solve():
    # 为应对大数据量,使用 sys.stdin.readline 以提高I/O效率
    input = sys.stdin.readline
    
    LOGK = 60
    
    n, k = map(int, input().split())
    x_map = [val - 1 for val in map(int, input().split())] # 转换为0-based
    a = list(map(int, input().split()))
    
    jump = [[0] * n for _ in range(LOGK)]
    jump[0] = x_map

    for p in range(1, LOGK):
        for i in range(n):
            jump[p][i] = jump[p - 1][jump[p - 1][i]]

    final_source_map = [0] * n
    for i in range(n):
        current_source = i
        for p in range(LOGK - 1, -1, -1):
            if (k >> p) & 1:
                current_source = jump[p][current_source]
        final_source_map[i] = current_source

    result = [0] * n
    for i in range(n):
        result[i] = a[final_source_map[i]]
        
    print(*result)

solve()

算法及复杂度

  • 算法:倍增 (Binary Lifting)
  • 时间复杂度
  • 空间复杂度