题目链接

【模板】矩阵快速幂

题目描述

给定一个 的整数方阵 以及一个非负整数 ,请计算矩阵 。当 时,约定 单位矩阵 。所有计算结果对 取模。

解题思路

本题要求计算一个矩阵的 次幂。如果 非常大,直接进行 次矩阵乘法()会超时。这是一个典型的可以使用快速幂 (Binary Exponentiation) 思想来解决的问题。

常规的快速幂用于计算一个数的幂(例如 ),其核心思想是将指数 进行二进制拆分,从而将时间复杂度从 优化到 。同样的思想也适用于矩阵乘法,因为矩阵乘法满足结合律(即 ),这是使用快速幂算法的前提。

算法步骤

  1. 定义矩阵乘法: 首先,我们需要一个函数来计算两个 矩阵的乘积。设 ,则矩阵 中的每个元素 由以下公式计算得出: 在计算过程中,每次乘法和加法的结果都需要对 取模,以防止溢出。特别需要注意,中间结果可能为负数,取模时需要确保结果落在 区间内。此操作的时间复杂度为

  2. 矩阵快速幂算法: 我们将整数快速幂的算法应用于矩阵:

    • 初始化一个结果矩阵 单位矩阵(主对角线为1,其余为0)。
    • 初始化一个基底矩阵 为输入的矩阵
    • 对指数 进行循环,直到 变为 0:
      • 如果 的当前二进制最低位为 1(即 ),则将 乘以
      • 自乘:
      • 右移一位(即 )。
    • 循环结束后, 矩阵即为最终结果

整个算法需要进行 次矩阵乘法,每次乘法复杂度为

代码

#include <iostream>
#include <vector>

using namespace std;

const int MOD = 1e9 + 7;

// 定义矩阵类型
using Matrix = vector<vector<long long>>;

// 矩阵乘法
Matrix multiply(const Matrix& a, const Matrix& b, int n) {
    Matrix c(n, vector<long long>(n, 0));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            for (int l = 0; l < n; ++l) {
                long long product = a[i][l] * b[l][j];
                c[i][j] = (c[i][j] + product % MOD + MOD) % MOD;
            }
        }
    }
    return c;
}

// 矩阵快速幂
Matrix matrix_pow(Matrix base, long long exp, int n) {
    Matrix res(n, vector<long long>(n, 0));
    // 初始化为单位矩阵
    for (int i = 0; i < n; ++i) {
        res[i][i] = 1;
    }

    while (exp > 0) {
        if (exp % 2 == 1) {
            res = multiply(res, base, n);
        }
        base = multiply(base, base, n);
        exp /= 2;
    }
    return res;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    long long k;
    cin >> n >> k;

    Matrix a(n, vector<long long>(n));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            cin >> a[i][j];
        }
    }

    Matrix result = matrix_pow(a, k, n);

    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            cout << result[i][j] << (j == n - 1 ? "" : " ");
        }
        cout << "\n";
    }

    return 0;
}
import java.util.Scanner;

public class Main {
    static final int MOD = 1_000_000_007;

    // 矩阵乘法
    public static long[][] multiply(long[][] a, long[][] b, int n) {
        long[][] c = new long[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                for (int l = 0; l < n; l++) {
                    long product = a[i][l] * b[l][j];
                    c[i][j] = (c[i][j] + product % MOD + MOD) % MOD;
                }
            }
        }
        return c;
    }

    // 矩阵快速幂
    public static long[][] matrixPow(long[][] base, long exp, int n) {
        long[][] res = new long[n][n];
        // 初始化为单位矩阵
        for (int i = 0; i < n; i++) {
            res[i][i] = 1;
        }

        while (exp > 0) {
            if (exp % 2 == 1) {
                res = multiply(res, base, n);
            }
            base = multiply(base, base, n);
            exp /= 2;
        }
        return res;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        long k = sc.nextLong();

        long[][] a = new long[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                a[i][j] = sc.nextLong();
            }
        }

        long[][] result = matrixPow(a, k, n);

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                System.out.print(result[i][j] + (j == n - 1 ? "" : " "));
            }
            System.out.println();
        }
    }
}
import sys

MOD = 10**9 + 7

def multiply(a, b, n):
    c = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            for l in range(n):
                c[i][j] = (c[i][j] + a[i][l] * b[l][j]) % MOD
    return c

def matrix_pow(base, exp, n):
    res = [[0] * n for _ in range(n)]
    # 初始化为单位矩阵
    for i in range(n):
        res[i][i] = 1

    while exp > 0:
        if exp % 2 == 1:
            res = multiply(res, base, n)
        base = multiply(base, base, n)
        exp //= 2
    return res

def main():
    try:
        input = sys.stdin.readline
        n, k = map(int, input().split())
        
        a = []
        for _ in range(n):
            a.append(list(map(int, input().split())))

        result = matrix_pow(a, k, n)

        for i in range(n):
            sys.stdout.write(" ".join(map(str, result[i])) + '\n')

    except (IOError, ValueError):
        return

main()

算法及复杂度

  • 算法:矩阵快速幂 (Matrix Exponentiation by Squaring)
  • 时间复杂度:。其中 是单次矩阵乘法的复杂度, 是快速幂算法所需的乘法次数。
  • 空间复杂度:,用于存储矩阵。