题目链接

题面:

题意:
给定一些约束条件求目标函数的最大值。
其中 det(A) ≠ 0 (mod 998244353),保证了在 mod 998244353 下矩阵A 可逆。

题解:
没有想明白为什么会在 $\sum_{i=1}^n\sum_{j=1}^nA_{i,j}x_ix_j=1$ 的条件下计算目标函数的最大值。

我们假设目标函数为 $f(x_1,...,x_n)=\sum_{i=1}^nb_ix_i$,因为最终求的是平方,那么一定在 $f$ 取极值时,最终答案取极值。

约束条件为 $g(x_1,...,x_n)=\sum_{i=1}^n\sum_{j=1}^nA_{i,j}x_ix_j=1$

拉格朗日函数为 $L(x_1,...,x_n,\lambda)=\sum_{i=1}^nb_ix_i+\lambda\left(\sum_{i=1}^n\sum_{j=1}^nA_{i,j}x_ix_j-1\right)$

对 $L$ 的每个变量求偏导,求偏导的时候把 $\sum_{i=1}^n\sum_{j=1}^nA_{i,j}x_ix_j$ 拆开即可。
注意 $A_{i,j}=A_{j,i}$,矩阵 $A$ 为对称矩阵。

$$\begin{cases} b_1+2\lambda(A_{1,1}x_1+A_{1,2}x_2+\dots+A_{1,n}x_n)=0\\ b_2+2\lambda(A_{2,1}x_1+A_{2,2}x_2+\dots+A_{2,n}x_n)=0\\ \quad\vdots\\ b_n+2\lambda(A_{n,1}x_1+A_{n,2}x_2+\dots+A_{n,n}x_n)=0\\ \sum_{i=1}^n\sum_{j=1}^nA_{i,j}x_ix_j=1 \end{cases}$$


$$\begin{cases}B+2\lambda Ax=0 & ①\\ x^TAx=1 & ②\end{cases}$$

$B+2\lambda Ax=0$ ----> $2\lambda Ax=-B$ ----> $x=-\dfrac{A^{-1}}{2\lambda}B$

$\sum_{i=1}^n x_ib_i=x^TB=B^Tx$ ----> $B^Tx=B^T\left(-\dfrac{A^{-1}}{2\lambda}B\right)=x^TB$ ----> $x^T=-B^T\dfrac{A^{-1}}{2\lambda}$(因为 $A$ 对称,所以 $(A^{-1})^T=A^{-1}$)

$x^TAx=1$ ----> $\left(-B^T\dfrac{A^{-1}}{2\lambda}\right)A\left(-\dfrac{A^{-1}}{2\lambda}B\right)=1$ ----> $\dfrac{1}{4\lambda^2}B^TA^{-1}B=1$

$\left(\sum_{i=1}^n b_ix_i\right)^2=(B^Tx)^2=\left(-\dfrac{1}{2\lambda}B^TA^{-1}B\right)^2=\dfrac{1}{4\lambda^2}(B^TA^{-1}B)(B^TA^{-1}B)=B^TA^{-1}B$

求解 $B^TA^{-1}B$ 即可。

代码:

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
#include<unordered_map>
#include<set>
// Type shorthands. Of these, only `ll` is used by the code visible in this file.
#define ui unsigned int
#define ll long long
#define llu unsigned ll
#define ld long double
// Container shorthands; not referenced by the visible code.
#define pr make_pair
#define pb push_back
// The next four macros expand to expressions over `cnt`, `t`, `l` and `r`,
// none of which are defined in the visible code; presumably leftovers from a
// shared segment-tree template -- they are unused here.
#define lc (cnt<<1)
#define rc (cnt<<1|1)
#define len(x) (t[(x)].r-t[(x)].l+1)
#define tmid ((l+r)>>1)
using namespace std;
// Large sentinel values; not referenced by the visible code.
const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
// Modulus for all matrix arithmetic (used by node, mypow and get).
const int mod=998244353;
// Floating-point helpers; not referenced by the visible code.
const double eps=1e-8;
const double pi=acos(-1.0);
// Not referenced by the visible code.
const int hp=13331;
// Side length of node::a. Matrices are indexed from 1, so the largest usable
// dimension is maxn-1.
const int maxn=210;
// Not referenced by the visible code.
const int maxm=100100;
const int up=100000;

/// Dense matrix of residues modulo `mod`, stored 1-indexed.
///
/// `n` is the number of rows and `m` the number of columns; only the cells
/// a[1..n][1..m] are meaningful. Storage is a fixed maxn x maxn array, so both
/// dimensions must be smaller than `maxn`.
struct node
{
    int n,m;
    int a[maxn][maxn];

    /// Clears the whole storage and writes 1 on the first `n` diagonal cells
    /// (the identity matrix when n == m).
    void init(void)
    {
        std::memset(a,0,sizeof(a));
        for(int i=1;i<=n;i++)
            a[i][i]=1;
    }

    /// Reads the n x m entries from stdin in row-major order.
    void input(void)
    {
        for(int i=1;i<=n;i++)
        {
            for(int j=1;j<=m;j++)
                std::scanf("%d",&a[i][j]);
        }
    }

    /// Swaps rows `x` and `y`.
    /// Walks all `m` columns (the original walked `n`, which is only correct
    /// for square matrices).
    void _swap(int x,int y)
    {
        for(int j=1;j<=m;j++)
            std::swap(a[x][j],a[y][j]);
    }

    /// Multiplies row `x` by `k` modulo `mod`.
    void mul_k(int x,int k)
    {
        for(int j=1;j<=m;j++)
            a[x][j]=1LL*a[x][j]*k%mod;
    }

    /// Adds `k` times row `x` to row `y` modulo `mod`.
    /// `k` may be negative; the trailing "+mod)%mod" brings the result back
    /// into [0, mod).
    void mul_k_add(int x,int k,int y)
    {
        for(int j=1;j<=m;j++)
            a[y][j]=((a[y][j]+1LL*a[x][j]*k)%mod+mod)%mod;
    }

    /// Prints the matrix to stdout, one row per line, each entry followed by
    /// a single space.
    void print(void)
    {
        for(int i=1;i<=n;i++)
        {
            for(int j=1;j<=m;j++)
                std::printf("%d ",a[i][j]);
            std::putchar('\n');
        }
    }

    /// Returns the transpose (an m x n matrix).
    node getT(void)
    {
        node ans;
        // Zero the storage so cells outside the m x n window are not left as
        // uninitialised stack memory.
        std::memset(ans.a,0,sizeof(ans.a));
        ans.n=m,ans.m=n;
        for(int i=1;i<=m;i++)
        {
            for(int j=1;j<=n;j++)
                ans.a[i][j]=a[j][i];
        }
        return ans;
    }

    /// Matrix product (*this) * b modulo `mod`.
    /// The result has `n` rows and `b.m` columns; the shared dimension is
    /// taken from this->m, and b.n is not checked against it.
    node operator * (const node &b) const
    {
        node ans;
        std::memset(ans.a,0,sizeof(ans.a));
        ans.n=n,ans.m=b.m;
        for(int i=1;i<=n;i++)
        {
            for(int j=1;j<=b.m;j++)
            {
                for(int k=1;k<=m;k++)
                    ans.a[i][j]=(ans.a[i][j]+1ll*a[i][k]*b.a[k][j])%mod;
            }
        }
        return ans;
    }

}a,inva,b,bt,ans;

int mypow(int a,int b)
{
    int ans=1;
    while(b)
    {
        if(b&1) ans=(ll)ans*a%mod;
        a=(ll)a*a%mod;
        b>>=1;
    }
    return ans;
}

/// Inverts the square matrix `a` modulo `mod` by Gauss-Jordan elimination.
///
/// @param a  Matrix to invert (a.n x a.n). It is modified in place: every row
///           operation is applied to it as the elimination proceeds.
/// @param b  Output. Reset to the a.n x a.n identity, then receives the same
///           row operations as `a`, which turns it into the inverse of `a`.
///
/// Assumes `a` is invertible modulo `mod`. If no non-zero pivot exists in a
/// column, the pivot inverse evaluates to 0 and the result is meaningless;
/// this is not reported to the caller.
void get(node &a,node &b)
{
    b.n=b.m=a.n;
    b.init();
    int n=a.n;

    // Forward pass: unit diagonal, zeros below it.
    for(int i=1;i<=n;i++)
    {
        if(!a.a[i][i])
        {
            // Bring up the first lower row with a non-zero entry in column i.
            for(int j=i+1;j<=n;j++)
            {
                if(a.a[j][i])
                {
                    a._swap(i,j);
                    b._swap(i,j);
                    break;
                }
            }
        }

        // Modular inverse of the pivot via Fermat's little theorem, computed
        // once and applied to both matrices.
        const int pivot_inv=mypow(a.a[i][i],mod-2);
        b.mul_k(i,pivot_inv);
        a.mul_k(i,pivot_inv);

        for(int j=i+1;j<=n;j++)
        {
            // Capture the multiplier before row j of `a` is overwritten.
            const int factor=-a.a[j][i];
            b.mul_k_add(i,factor,j);
            a.mul_k_add(i,factor,j);
        }
    }

    // Backward pass: clear the entries above the diagonal.
    for(int i=n;i>=1;i--)
    {
        for(int j=i-1;j>=1;j--)
        {
            const int factor=-a.a[j][i];
            b.mul_k_add(i,factor,j);
            a.mul_k_add(i,factor,j);
        }
    }
}

int main(void)
{
    int n;
    while(scanf("%d",&n)!=EOF)
    {
        a.n=n,a.m=n;
        a.input();
        b.n=n,b.m=1;
        b.input();
        get(a,inva);
        printf("%d\n",(b.getT()*inva*b).a[1][1]);
    }

    return 0;
}