题目链接
题目描述
给定一个 的整数方阵
以及一个非负整数
,请计算矩阵
。当
时,约定
为
单位矩阵
。所有计算结果对
取模。
解题思路
本题要求计算一个矩阵的 次幂。如果
非常大,直接进行
次矩阵乘法(
)会超时。这是一个典型的可以使用快速幂 (Binary Exponentiation) 思想来解决的问题。
常规的快速幂用于计算一个数的幂(例如 ),其核心思想是将指数
进行二进制拆分,从而将时间复杂度从
优化到
。同样的思想也适用于矩阵乘法,因为矩阵乘法满足结合律(即
),这是使用快速幂算法的前提。
算法步骤
-
定义矩阵乘法: 首先,我们需要一个函数来计算两个
矩阵的乘积。设
,则矩阵
中的每个元素
由以下公式计算得出:
在计算过程中,每次乘法和加法的结果都需要对
取模,以防止溢出。特别需要注意,中间结果可能为负数,取模时需要确保结果落在
区间内。此操作的时间复杂度为
。
-
矩阵快速幂算法: 我们将整数快速幂的算法应用于矩阵:
- 初始化一个结果矩阵
为单位矩阵(主对角线为1,其余为0)。
- 初始化一个基底矩阵
为输入的矩阵
。
- 对指数
进行循环,直到
变为 0:
- 如果
的当前二进制最低位为 1(即
),则将
乘以
:
。
- 将
自乘:
。
- 将
右移一位(即
)。
- 如果
- 循环结束后,
矩阵即为最终结果
。
- 初始化一个结果矩阵
整个算法需要进行 次矩阵乘法,每次乘法复杂度为
。
代码
#include <iostream>
#include <vector>
using namespace std;
const int MOD = 1e9 + 7;
// 定义矩阵类型
using Matrix = vector<vector<long long>>;
// 矩阵乘法
Matrix multiply(const Matrix& a, const Matrix& b, int n) {
Matrix c(n, vector<long long>(n, 0));
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
for (int l = 0; l < n; ++l) {
long long product = a[i][l] * b[l][j];
c[i][j] = (c[i][j] + product % MOD + MOD) % MOD;
}
}
}
return c;
}
// 矩阵快速幂
Matrix matrix_pow(Matrix base, long long exp, int n) {
Matrix res(n, vector<long long>(n, 0));
// 初始化为单位矩阵
for (int i = 0; i < n; ++i) {
res[i][i] = 1;
}
while (exp > 0) {
if (exp % 2 == 1) {
res = multiply(res, base, n);
}
base = multiply(base, base, n);
exp /= 2;
}
return res;
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n;
long long k;
cin >> n >> k;
Matrix a(n, vector<long long>(n));
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
cin >> a[i][j];
}
}
Matrix result = matrix_pow(a, k, n);
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
cout << result[i][j] << (j == n - 1 ? "" : " ");
}
cout << "\n";
}
return 0;
}
import java.util.Scanner;
public class Main {
static final int MOD = 1_000_000_007;
// 矩阵乘法
public static long[][] multiply(long[][] a, long[][] b, int n) {
long[][] c = new long[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
for (int l = 0; l < n; l++) {
long product = a[i][l] * b[l][j];
c[i][j] = (c[i][j] + product % MOD + MOD) % MOD;
}
}
}
return c;
}
// 矩阵快速幂
public static long[][] matrixPow(long[][] base, long exp, int n) {
long[][] res = new long[n][n];
// 初始化为单位矩阵
for (int i = 0; i < n; i++) {
res[i][i] = 1;
}
while (exp > 0) {
if (exp % 2 == 1) {
res = multiply(res, base, n);
}
base = multiply(base, base, n);
exp /= 2;
}
return res;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
long k = sc.nextLong();
long[][] a = new long[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
a[i][j] = sc.nextLong();
}
}
long[][] result = matrixPow(a, k, n);
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
System.out.print(result[i][j] + (j == n - 1 ? "" : " "));
}
System.out.println();
}
}
}
import sys
MOD = 10**9 + 7
def multiply(a, b, n):
c = [[0] * n for _ in range(n)]
for i in range(n):
for j in range(n):
for l in range(n):
c[i][j] = (c[i][j] + a[i][l] * b[l][j]) % MOD
return c
def matrix_pow(base, exp, n):
res = [[0] * n for _ in range(n)]
# 初始化为单位矩阵
for i in range(n):
res[i][i] = 1
while exp > 0:
if exp % 2 == 1:
res = multiply(res, base, n)
base = multiply(base, base, n)
exp //= 2
return res
def main():
try:
input = sys.stdin.readline
n, k = map(int, input().split())
a = []
for _ in range(n):
a.append(list(map(int, input().split())))
result = matrix_pow(a, k, n)
for i in range(n):
sys.stdout.write(" ".join(map(str, result[i])) + '\n')
except (IOError, ValueError):
return
main()
算法及复杂度
- 算法:矩阵快速幂 (Matrix Exponentiation by Squaring)
- 时间复杂度:
。其中
是单次矩阵乘法的复杂度,
是快速幂算法所需的乘法次数。
- 空间复杂度:
,用于存储矩阵。