题意
给出n个数,其中值为-1的需要从[1,m]中选一个数替代。要求替代后,不能出现某个子序列是回文的,求方案数。
题解
题目要求不能是回文,这就等价于不能出现 a[i]==a[i+2]的情况,这个转换很重要。
那么我们就可以根据奇偶分开来处理,答案就是两者的乘积。
接下来是具体求方案数
设序列 A,−1,−1,...,−1,−1,B ,其中连续的 −1的个数为 k
设 S[k]表示长为 k的,两端相同的方案数, D[k]表示长为 k的,两端不同的方案数
S[0]=0,D[0]=1
当k为奇数
- A==B
S[k]=S[k/2]2+(m−1)∗D[k/2]2 - A!=B
D[k]=S[k/2]∗D[k/2]∗2+(m−2)∗D[k/2]2
当k为偶数
- A==B
S[k]=(m−1)∗D[k−1] - A!=B
D[k]=S[k−1]+(m−2)∗D[k−1]
以上公式推导方法为
当k为奇数,枚举中间的数分别为: A、B、其他
当k为偶数,枚举最后一个为 −1的数分别为: A、B、其他
需要注意的是,当 A、B都为空,或其中一个为空时,需要特殊考虑。
代码
#include<bits/stdc++.h>
#define N 100010
#define INF 0x3f3f3f3f
#define eps 1e-10
// #define pi 3.141592653589793
#define P 998244353
#define LL long long
#define pb push_back
#define fi first
#define se second
#define cl clear
#define si size
#define lb lower_bound
#define ub upper_bound
#define mem(x) memset(x,0,sizeof x)
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/priority_queue.hpp>
using namespace __gnu_pbds;
typedef pair<int,int> pi;
typedef pair<int,pi> pp;
typedef __gnu_pbds::priority_queue<pp,greater<pp> > heap;
LL S[N],D[N],n,m,d[N];
vector<int> a,b;
LL qumi(LL x,LL y){
LL res=1;
while(y){if (y&1) res=res*x%P;x=x*x%P;y>>=1;}
return res;
}
LL spy(vector<int> &a){
int num=0,fg=0;
for (auto i:a) if (i==-1) num++;
if (num==a.size()){
return m*qumi(m-1,a.size()-1)%P;
}
num=0;for (auto i:a) d[++num]=i;
fg=0; d[0]=1;d[num+1]=1;
LL res=1;
for (int i=1;i<=num;i++){
if (i<num) if (d[i]==d[i+1] && d[i]>0) return 0;
if (d[i]==-1 && d[i-1]>0) fg=i-1;
if (d[i]==-1 && d[i+1]>0){
if (!fg) res=res*qumi(m-1,i)%P;else
if (i==num) res=res*qumi(m-1,i-fg)%P;else
if (d[fg]==d[i+1]) res=res*S[i-fg]%P;else
res=res*D[i-fg]%P;
}
}
return res;
}
int main()
{
cin>>n>>m;
S[0]=0; D[0]=1;
for (int i=1,k;i<=n;i+=2){
k=i;
S[k]=(S[k/2]*S[k/2]%P+(m-1)*D[k/2]%P*D[k/2])%P;
D[k]=(S[k/2]*D[k/2]*2%P+D[k/2]*D[k/2]%P*(m-2))%P;
k++;
S[k]=D[i]*(m-1)%P;
D[k]=(S[i]+D[i]*(m-2))%P;
}
for (int i=1,x;i<=n;i+=2){
sc(x); a.pb(x);
if (i<n) {sc(x),b.pb(x);}
}
cout<<spy(a)*spy(b)%P;
return 0;
}