题意
做n次选择,每次选择(A,B,C,D,AB,AC,AD,BC,BD,CD,ABC,ACD,ABD,BCD,ABCD)
中的一种
限制,每次选择完后,从最开始到当前选择的所有内容中,A和C出现次数差的绝对值不大于1,B和D出现次数差的绝对值不大于2
给定 n≤1e5, 求完成n次选择的方案数
算法
直接模拟递推
把题目抽象成数学。
- 每次在1到15 中选择一个数,以下是简单的映射关系
A | B | C | D | AB | AC | AD | BC | BD | CD | ABC | ACD | ABD | BCD | ABCD |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0001 | 0010 | 0100 | 1000 | 0011 | 0101 | 1001 | 0110 | 1010 | 1100 | 0111 | 1101 | 1011 | 1110 | 1111 |
A,B,C,D是否被选,就和对应二进制位是否为1对应
- 记录A和C个数的差值,记录B和D个数的差值,并按照题目限制控制合法的方案。
令 ans[i][j]= 表示到当前位置,AC的差值为i,BD的差值为j的方案数
每次模拟选择一个数(1~15)
ans当前[AC][BD]+=ans上一次[AC−选择数导致的A和C差值的变化][BD−选择数导致的B和D差值的变化]
- 变成代码
注意到 C++ 下标不能使用负数,我们分别做值映射
对于AC的差值 −1,0,1=>0,1,2
对于BD的差值 −2,−1,0,1,2=>0,1,2,3,4
所以默认值 ans[i][j]=0,其中 ans[1][2]=1 表示,还未选择前,AC和BD的差值都为0的情况
以此循环n次,可以计算出n次选择后,AC和BD不同差值的方案数
最终的答案为∑i=0..2,j=0..4ans[i][j]
对于需要取模的部分注意取模即可
代码
class Solution {
public:
/**
* 代码中的类名、方法名、参数名已经指定,请勿修改,直接返回方法规定的值即可
* @param n int整型
* @return int整型
*/
int solve(int n) {
const int mod = 1000000007;
// arr[i][j] =
// A的个数减去C的个数+1为i
// B的个数减去D的个数+2为j
// 时的方案数
vector<vector<long long> > arr = vector<vector<long long> >(3,vector<long long>(5,0));
arr[1][2] = 1;
// 递推n个位置
for(int i = 0;i < n;i++){
// 状态转移的结果为ans
vector<vector<long long> > ans = vector<vector<long long> >(3,vector<long long>(5,0));
for(int ac = 0;ac < 3;ac++){ // 枚举现有的ac差
for(int bd = 0;bd < 5;bd++){ // 枚举现有的bd的差
for(int j = 1;j < 16;j++){ // 枚举选项
int diff_ac = (j>>0)%2 - (j>>1)%2; // 计算ac差的变化
int diff_bd = (j>>2)%2 - (j>>3)%2; // 计算bd差的变化
int new_ac = ac+diff_ac; // 新的ac的差
if( new_ac < 0 || new_ac >= 3)continue;
int new_bd = bd+diff_bd; // 新的bd的差
if( new_bd < 0 || new_bd >= 5)continue;
(ans[new_ac][new_bd] += arr[ac][bd])%=mod; // 状态转移
}
}
}
arr = ans;
}
long long result = 0;
for(int ac = 0;ac<3;ac++){ // 枚举ac的差
for(int bd = 0;bd < 5;bd++){ // 枚举bd的差
(result += arr[ac][bd])%=mod; // 统计答案
}
}
return result;
}
};
复杂度分析
时间复杂度: 我们循环了n次,每一次模拟选择,模拟选择的代价是常数3⋅5⋅16, 所以 总时间复杂度为O(n)
空间复杂度: 我们仅用了 一个常数大小的结果数组,和一个常数大小的临时数组来记录方案,所以空间复杂度是常数O(1)
矩阵乘法/快速幂
我们发现,上面的递推关系中,与n无关,且每次转换关系又是线性加和。
所以我们把AC和BD的值看成一个状态整体(代码中encode
函数实现),有3⋅5=15种
不同状态整体的转义系数是常数,满足这个条件,就可以变成矩阵乘法 其中i行j列表示,上一个状态是i,转移为状态j的方案数,矩阵为15⋅15的大小,太大不要手动推算,矩阵具体的值由代码算出。
而矩阵乘法可以使用快速幂来提高效率
考虑到初始 矩阵为 (0⋯1⋯0), 仅有表示AC和BD差值为0的项(也就是 encode(1,2)
)为1
所以最终的答案为 ∑i=0..15(basematrix)n[encode(1,2)][i]
代码
class Solution {
public:
typedef long long ll;
#define rep(i,a,n) for (ll i=a;i<n;i++)
const int mod = 1000000007;
// 矩阵乘法
vector<vector<ll>> mul(vector<vector<ll>>& m1,vector<vector<ll>>& m2){
vector<vector<ll>> res = vector<vector<ll>>(m1.size(),vector<ll>(m2[0].size(),0));
rep(i,0,m1.size()){
rep(j,0,m2[0].size()){
rep(k,0,m1[0].size()){
(res[i][j]+=m1[i][k]*m2[k][j]%mod)%=mod;
}
}
}
return res ;
}
// 矩阵幂次
vector<vector<ll>> matrixp(vector<vector<ll>>& m1, ll pwr){
// 单位矩阵
vector<vector<ll>> res = vector<vector<ll>>(m1.size(),vector<ll>(m1.size(),0));
rep(i,0,m1.size()){
res[i][i] = 1;
}
// 快速幂
while(pwr){ // 幂次不为0
if(pwr%2)res = mul(res,m1); // 当前二进制位为1 则乘上翻倍后的基数
m1 = mul(m1,m1); // 基数翻倍
pwr/=2; // 幂次除以2
}
return res ;
}
int encode(int v0,int v1){ // v1 最大小于5,所以编码两个数成一个数作为状态
return v0*5+v1;
}
/**
* 代码中的类名、方法名、参数名已经指定,请勿修改,直接返回方法规定的值即可
* @param n int整型
* @return int整型
*/
int solve(int n) {
const int sz = 3*5;
// 状态转换关系计算
vector<vector<long long> > matrix = vector<vector<long long> >(sz,vector<long long>(sz,0));
for(int ac = 0;ac < 3;ac++){ // 枚举所有A-C+1的差
for(int bd = 0;bd < 5;bd++){ // 枚举所有B-D+2的差
for(int j = 1;j < 16;j++){ // 枚举所有选项
int diff_ac = (j>>0)%2 - (j>>1)%2; // AC差的变化
int diff_bd = (j>>2)%2 - (j>>3)%2; // BD差的变化
int new_ac = ac+diff_ac; // 新的AC的差
if( new_ac < 0 || new_ac >= 3)continue;
int new_bd = bd+diff_bd; // 新的BD的差
if( new_bd < 0 || new_bd >= 5)continue;
matrix[encode(ac,bd)][encode(new_ac,new_bd)] += 1; // 写状态转移矩阵
}
}
}
// 矩阵n次方
vector<vector<long long> > matrixResult = matrixp(matrix, n);
long long result = 0;
for(int idx = 0;idx<sz;idx++){ // 所有合法的状态
(result += matrixResult[encode(1,2)][idx])%=mod; // 统计合法的结果
}
return result;
}
};
复杂度分析
时间复杂度: 我们通过快速幂,计算的是转换矩阵的n次方,矩阵大小为常数,所以总时间复杂度为O(log(n))
空间复杂度: 我们仅用了 一个常数大小的结果矩阵,和一个计算矩阵幂次的非递归函数,所以空间复杂度是常数O(1)
知识点
- 对问题的抽象化,虽然题目是ABCD,但是实际上因为是选择所有情况,熟悉二进制的应该能立刻想到1到15能完成一一映射
- int 的溢出,虽然输入输出都是int,但是 涉及到int的加法乘法,可能有溢出的情况时,记得使用long long 来完成中间过程的运算避免溢出
- 与项数无关的递推式,可以想到矩阵乘法