题面:
题意:
给定一些约束条件求目标函数的最大值。
其中 det(A) ≠ 0 (mod 998244353),保证了在 mod 998244353 下矩阵A 可逆。
题解:
没有想明白为什么会在 ∑i=1n∑j=1nAi,jxixj=1 的条件下计算目标函数的最大值。
我们假设目标函数为 f(x1,...,xn)=∑i=1nbixi,因为最终求的是平方,那么一定在 f取极值时,最终答案取极值
约束条件为 g(x1,...,xn)=∑i=1n∑j=1nAi,jxixj=1
拉个朗日函数为 L(x1,...,xn,λ)=∑i=1nbixi+λ(∑i=1n∑j=1nAi,jxixj−1)
对L的每个变量求偏导,求偏导的时候 ∑i=1n∑j=1nAi,jxixj拆开即可。
注意 Ai,j=Aj,i,矩阵A为对称矩阵
⎩⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎨⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎧b1+2∗λ(A1,1x1+A1,2x2+...+A1,nxn)=0b2+2∗λ(A2,1x1+A2,2x2+...+A2,nxn)=0...bn+2∗λ(An,1x1+An,2x2+...+An,nxn)=0∑i=1n∑j=1nAi,jxixj=1
即
{B+2λAx=0①xTAx=1②
B+2λAx=0----> 2λAx=−B----> x=−2λA−1∗B
∑xibi=xTB=BTx----> BTx=BT(−2λA−1∗B)=xTB----> xT=−BT2λA−1
xTAx=1----> −BT2λA−1∗A∗−2λA−1∗B=1----> 4λ21BTA−1B=1
(∑BTx)2=(−2λ1BTA−1B)2=4λ21(BTA−1B)(BTA−1B)=(BTA−1B)
求解 BTA−1B即可。
代码:
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
#include<unordered_map>
#include<set>
#define ui unsigned int
#define ll long long
#define llu unsigned ll
#define ld long double
#define pr make_pair
#define pb push_back
#define lc (cnt<<1)
#define rc (cnt<<1|1)
#define len(x) (t[(x)].r-t[(x)].l+1)
#define tmid ((l+r)>>1)
using namespace std;
const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
const int mod=998244353;
const double eps=1e-8;
const double pi=acos(-1.0);
const int hp=13331;
const int maxn=210;
const int maxm=100100;
const int up=100000;
struct node
{
int n,m;
int a[maxn][maxn];
void init(void)
{
memset(a,0,sizeof(a));
for(int i=1;i<=n;i++)
a[i][i]=1;
}
void input(void)
{
for(int i=1;i<=n;i++)
{
for(int j=1;j<=m;j++)
scanf("%d",&a[i][j]);
}
}
void _swap(int x,int y)
{
for(int i=1;i<=n;i++)
swap(a[x][i],a[y][i]);
}
void mul_k(int x,int k)
{
for(int i=1;i<=n;i++)
a[x][i]=(ll)a[x][i]*k%mod;
}
void mul_k_add(int x,int k,int y)
{
for(int i=1;i<=n;i++)
a[y][i]=((a[y][i]+(ll)a[x][i]*k)%mod+mod)%mod;
}
void print(void)
{
for(int i=1;i<=n;i++)
{
for(int j=1;j<=m;j++)
printf("%d ",a[i][j]);
putchar('\n');
}
}
node getT(void)
{
node ans;
ans.n=m,ans.m=n;
for(int i=1;i<=m;i++)
{
for(int j=1;j<=n;j++)
ans.a[i][j]=a[j][i];
}
return ans;
}
node operator * (const node &b) const
{
node ans;
memset(ans.a,0,sizeof(ans.a));
ans.n=n,ans.m=b.m;
for(int i=1;i<=n;i++)
{
for(int j=1;j<=b.m;j++)
{
for(int k=1;k<=m;k++)
ans.a[i][j]=(ans.a[i][j]+1ll*a[i][k]*b.a[k][j])%mod;
}
}
return ans;
}
}a,inva,b,bt,ans;
int mypow(int a,int b)
{
int ans=1;
while(b)
{
if(b&1) ans=(ll)ans*a%mod;
a=(ll)a*a%mod;
b>>=1;
}
return ans;
}
void get(node &a,node &b)
{
b.n=b.m=a.n;
b.init();
int n=a.n;
for(int i=1;i<=n;i++)
{
if(!a.a[i][i])
{
for(int j=i+1;j<=n;j++)
{
if(a.a[j][i])
{
a._swap(i,j);
b._swap(i,j);
break;
}
}
}
b.mul_k(i,mypow(a.a[i][i],mod-2));
a.mul_k(i,mypow(a.a[i][i],mod-2));
for(int j=i+1;j<=n;j++)
{
b.mul_k_add(i,-a.a[j][i],j);
a.mul_k_add(i,-a.a[j][i],j);
}
}
for(int i=n;i>=1;i--)
{
for(int j=i-1;j>=1;j--)
{
b.mul_k_add(i,-a.a[j][i],j);
a.mul_k_add(i,-a.a[j][i],j);
}
}
}
int main(void)
{
int n;
while(scanf("%d",&n)!=EOF)
{
a.n=n,a.m=n;
a.input();
b.n=n,b.m=1;
b.input();
get(a,inva);
printf("%d\n",(b.getT()*inva*b).a[1][1]);
}
return 0;
}