title: '[动态规划]矩阵加速'
date: 2026-03-31 21:33:43
tags: 动态规划 数论
矩阵加速DP转移
前置知识:矩阵快速幂
给定一个 阶的矩阵
以及一个非负整数
,要计算矩阵
,当
时,
是
阶单位矩阵
。
即有,其中有
个
。
若 数值小的时候,可以直接暴力计算。但是如果
的情况时,暴力做法必然会超时。
在此之前我们学过快速幂:即利用二进制的思想,将原来的底数不断扩大,做到 的时间复杂度求出一个数的
次方。
i64 binpow(i64 a,i64 b,i64 m) {
a%=m;
i64 res=1;
while(b>0) {
if(b&1)res=res*a%m;
a=a*a%m;
b>>=1;
}
return res;
}
利用这个思想,我们可以把矩阵看作一个整体,套上整数快速幂的模板:
//创建Mat结构体,包含矩阵乘法
Mat binpow(Mat A,i64 x,i64 mod){
Mat I;
I.init_b();//初始化单位矩阵
while(x){
if(x&1) I=I*A;
A=A*A;
x>>=1;
}
return I;
}
求矩阵的 次方这个问题就迎刃而解了。
矩阵加速
我们知道,动态规划可以看作一个递推的过程。其在遍历的过程中必然会推导出一个状态转移方程,只要整理清楚了这个状态转移方程,就可以对这个状态转移方程进行矩阵加速。举一个非常常见的例子:
斐波那契数列:
前面两个初始值不用太关心,我们主要关心当 时的情况:
我们可以得到:
假设向量 :
因为具有递推关系故存在一个 阶矩阵
于
前面才能进行状态转移,即:
也就是:
然后我们就可以利用矩阵的乘法运算,用待定系数法求出这四个未知数。
于是我们就可以得到:
假设矩阵 经过
自乘之后变成了:
由于矩阵乘法就有 ;
矩阵自乘的过程可以运用矩阵快速幂。因此上式所求出的值即是第 项斐波那契数列的值(一般会进行取模操作)。
这样一来我们就可以以 的时间复杂度求出第
项的各个值了。
是矩阵的阶数。
例题
https://www.luogu.com.cn/problem/P3216
P3216 [HNOI2011] 数学作业
题目描述
小 C 数学成绩优异,于是老师给小 C 留了一道非常难的数学作业题:
给定正整数 ,要求计算
的值,其中
是将
所有正整数 顺序连接起来得到的数。
例如,,
。小 C 想了大半天终于意识到这是一道不可能手算出来的题目,于是他只好向你求助,希望你能编写一个程序帮他解决这个问题。
【数据范围】
对于 的数据,
;
对于 的数据,
,
。
不妨枚举一下:
,即有
,
,即有
,
有,
我们可以鲁莽的得到这个题目的状态转移方程:
得出 状态转移方程就很简单了。但是题目中的
,数据特别大,可以考虑矩阵乘法。
对于 ,在某一段范围是固定的,比如说,
中这些数的
都是固定的。于是我们对n 进行分段考虑。
接下来我们推导如何用矩阵转移这个递推式
我们先对中的项进行转化,全部变成包含
的式子。
则我们创建两个 的向量
以及一个
的矩阵
,
则有以下关系:
由上述状态转移方程有:
以及:
故有:
由于 ,所以有
,
,所以
,
1=1,所以。
可以得出矩阵 :
由此可以得到:
最后再对的长度进行
状态转移矩阵快速幂即可:
Code.:
struct Mat{
i64 mat[4][4];
Mat(){
memset(mat,0,sizeof(mat));//矩阵初始化为0
}
void init_dp(i64 sz){//转移矩阵初始化
mat[1][2]=1;
mat[1][3]=1;
mat[2][2]=1;
mat[2][3]=1;
mat[3][3]=1;
mat[1][1]=sz;
mat[1][1]%=m;
}
void init_b(){//初始化为单位矩阵
mat[1][1]=1;
mat[2][2]=1;
mat[3][3]=1;
}
Mat operator * (const Mat &other){//结构体定义矩阵乘法
Mat temp;
for(int i=1;i<=3;i++){
for(int j=1;j<=3;j++){
i128 res=0;
for(int k=1;k<=3;k++){
res+=mat[i][k]*other.mat[k][j];
res%=m;
}
temp.mat[i][j]=res;
}
}
return temp;
}
};
Mat binpow(Mat A,i64 x,i64 mod){//矩阵快速幂
Mat I;
I.init_b();
while(x){
if(x&1) I=I*A;
A=A*A;
x>>=1;
}
return I;
}
void solve(){
cin>>n>>m;
i64 f=0;
for(i128 i=1,L=1;L<=n;i++,L*=10){
i128 R=min((i128)L*10-1,(i128)n);
i128 d=R-L+1;
Mat base;
base.init_dp(binpow(10,i,m));
Mat temp=binpow(base,d,m);
i64 nf=(temp.mat[1][1]*f+temp.mat[1][2]*((L-1)%m)+temp.mat[1][3])%m;
f=nf;
}
cout<<f<<endl;
}
记得开__int128,不然会溢出超时。
END.

京公网安备 11010502036488号