(牛客第四场)子段乘积(尺取法、拓展欧几里得算法、矩阵快速幂、逆元、费马小定理)

链接

给出长度为n的数列,求其长度为k的连续字段的乘积对取模余数的最大值。

先看几个预备知识:

拓展欧几里得算法:对于,一定有一组整数解

利用数学归纳法证明:

①当b=0时,y=0,x=1,显然成立。

②假设成立,那么对于也成立。

③假设存在解使得

④需要证明存在解

移项得:。将其代入要证明的式子得:
$x_1=y_2,y_1=(x_2-k*y_2)$就行了。

模板:

void ext_gcd(int a,int b,int &d,int &x,int &y){
    if(!b) {d=a;x=1;y=0;}
    else {gcd(b,a%b,d,y,x);y-=x*(a/b);}
}

逆元,对于除法的取模我们不能像加减和乘法那样直接取模,例如的式子是错误的,那么如果我们想要对除法结果取模,我们该怎么办?答案是用逆元。

假如有像满足的式子存在,y就被称为a的逆元。

逆元可以实现上面的除法取模的问题。

对于求解逆元,我们通常有两种方法:

①拓展欧几里得算法

可以写成,也就等同于是,求解这个方程组。返回即可。

模板:

ll inv(ll a,ll p){
    ll x,y,d;
    ext_gcd(a,p,d,x,y);
    x=(x%p+p)%p;   //将x从负数转化为正数
    return x;
}

②费马小定理

费马小定理为素数,对于任意整数都有

无法被整除,则有

根据费马小定理,我们可以通过矩阵快速幂求得逆元,因为由上式有

费马小定理有个限制,就是模数必须为素数,不过对于本题是可以的。

接下来就是用尺取法求解答案,对于本题,采用的方法是维护变量代表字段的最后一位数,temp代表以为结尾的k长度子串的乘积。

/////拓展欧几里得法求逆元//////////
#include <iostream>
#include <cstdio>
#include <stack>
#include <string>
using namespace std;
typedef long long ll;
const int maxn=200010;
const ll mod=998244353;
ll a[maxn];
//如上求逆元
void ext_gcd(ll a,ll b,ll &d,ll &x,ll &y){
    if(!b){d=a; x=1; y=0;}
    else {ext_gcd(b,a%b,d,y,x); y-=x*(a/b);}
}
ll inv(ll a,ll p){
    ll x,y,d;
    ext_gcd(a,p,d,x,y);
    x=(x%p+p)%p;
    return x;
}

int main(){
    int n,k;
    cin>>n>>k;
    ll zero=0;//当前长度为k的字段有多少0,我们维护这个变量的意义是,当子段里有0,那么子段的乘积就为0
    ll ans=0;
    ll temp=1;    //temp维护区间内除0以外的数的乘积
    //尺取法
    for(int i=0;i<n;i++){
        cin>>a[i];
        if(!a[i]) zero++;     //如果遇到0,则加一
        else temp=temp*a[i]%mod;  //否则乘以这个数

        if(i>=k){
            if(!a[i-k]) zero--; //如果k长度子段第一个元素前一个元素为0,说明我们的区间已经移动过去了
            else temp=temp*inv(a[i-k],mod)%mod; //否则乘以它的逆元,因为0没有逆元
        }
        if(i>=k-1&&!zero) ans=ans>temp?ans:temp;//如果zero=0,则说明该区间里面的数全部大于0,就可以和答案比较了
    }
    printf("%lld",ans);
    return 0;
}
//////费马小定理(快速幂)求逆元//////////
#include <iostream>
#include <cstdio>
#include <stack>
#include <string>
using namespace std;
typedef long long ll;
const int maxn=200010;
const ll mod=998244353;
ll a[maxn];
//如上求逆元
ll poww(ll a,ll n){
    ll res=1;
    while(n>0){
        if(n&1) res=(res*a)%mod;
        a=(a*a)%mod;
        n>>=1;
    }
    return res;
}
// 就是求a^p-2,就这么简单
ll inv(ll a,ll p){
    return poww(a,p-2); 
}

int main(){
    int n,k;
    cin>>n>>k;
    ll zero=0;//当前长度为k的字段有多少0,我们维护这个变量的意义是,当子段里有0,那么子段的乘积就为0
    ll ans=0;
    ll temp=1;    //temp维护区间内除0以外的数的乘积
    //尺取法
    for(int i=0;i<n;i++){
        cin>>a[i];
        if(!a[i]) zero++;     //如果遇到0,则加一
        else temp=temp*a[i]%mod;  //否则乘以这个数

        if(i>=k){
            if(!a[i-k]) zero--; //如果k长度子段第一个元素前一个元素为0,说明我们的区间已经移动过去了
            else temp=temp*inv(a[i-k],mod)%mod; //否则乘以它的逆元,因为0没有逆元
        }
        if(i>=k-1&&!zero) ans=ans>temp?ans:temp;//如果zero=0,则说明该区间里面的数全部大于0,就可以和答案比较了
    }
    printf("%lld",ans);
    return 0;
}