题面:
题意:
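给定 $n$ 阶对称矩阵 $A$ 和向量 $b=(b_1,\dots,b_n)^T$,在约束条件 $\sum_{i=1}^n\sum_{j=1}^nA_{i,j}x_ix_j=1$ 下,求目标函数 $\left(\sum_{i=1}^nb_ix_i\right)^2$ 的最大值,答案对 $998244353$ 取模。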
其中 det(A) ≠ 0 (mod 998244353),保证了在 mod 998244353 下矩阵A 可逆。
题解:
没有想明白题目为什么要在 $\sum_{i=1}^n\sum_{j=1}^nA_{i,j}x_ix_j=1$ 的条件下计算目标函数的最大值。
我们设目标函数为 $f(x_1,\dots,x_n)=\sum_{i=1}^nb_ix_i$。由于最终求的是 $f^2$,而 $f^2$ 的最大值必然在 $f$ 取极值(最大值或最小值)时取得,所以只需考虑 $f$ 的极值点。
约束条件为 $g(x_1,\dots,x_n)=\sum_{i=1}^n\sum_{j=1}^nA_{i,j}x_ix_j=1$。
构造拉格朗日函数 $L(x_1,\dots,x_n,\lambda)=\sum_{i=1}^nb_ix_i+\lambda\left(\sum_{i=1}^n\sum_{j=1}^nA_{i,j}x_ix_j-1\right)$。
对 $L$ 的每个变量求偏导,求偏导时把 $\sum_{i=1}^n\sum_{j=1}^nA_{i,j}x_ix_j$ 展开即可。
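注意 $A_{i,j}=A_{j,i}$,即矩阵 $A$ 为对称矩阵。因此 $\dfrac{\partial}{\partial x_k}\sum_{i=1}^n\sum_{j=1}^nA_{i,j}x_ix_j=\sum_{j=1}^nA_{k,j}x_j+\sum_{i=1}^nA_{i,k}x_i=2\sum_{j=1}^nA_{k,j}x_j$。令各偏导数为 $0$(对 $\lambda$ 的偏导就是约束条件本身),得到方程组: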
$$
\begin{cases}
b_1+2\lambda(A_{1,1}x_1+A_{1,2}x_2+\dots+A_{1,n}x_n)=0\\
b_2+2\lambda(A_{2,1}x_1+A_{2,2}x_2+\dots+A_{2,n}x_n)=0\\
\quad\vdots\\
b_n+2\lambda(A_{n,1}x_1+A_{n,2}x_2+\dots+A_{n,n}x_n)=0\\
\sum_{i=1}^n\sum_{j=1}^nA_{i,j}x_ix_j=1
\end{cases}
$$
即
$$
\begin{cases}
B+2\lambda Ax=0 & ①\\
x^TAx=1 & ②
\end{cases}
$$
其中 $B=(b_1,\dots,b_n)^T$,$x=(x_1,\dots,x_n)^T$。
由 ① 得 $2\lambda Ax=-B$,又 $A$ 可逆,所以 $x=-\dfrac{1}{2\lambda}A^{-1}B$(当 $B\neq 0$ 时必有 $\lambda\neq 0$)。
注意 $\sum_{i=1}^nx_ib_i=x^TB=B^Tx$。对上式两边取转置:由于 $A$ 对称,有 $(A^{-1})^T=(A^T)^{-1}=A^{-1}$,所以 $x^T=-\dfrac{1}{2\lambda}B^TA^{-1}$。
代入 ②:$x^TAx=\left(-\dfrac{1}{2\lambda}B^TA^{-1}\right)A\left(-\dfrac{1}{2\lambda}A^{-1}B\right)=\dfrac{1}{4\lambda^2}B^TA^{-1}B=1$。
于是 $\left(\sum_{i=1}^nb_ix_i\right)^2=(B^Tx)^2=\left(-\dfrac{1}{2\lambda}B^TA^{-1}B\right)^2=\dfrac{1}{4\lambda^2}(B^TA^{-1}B)(B^TA^{-1}B)=B^TA^{-1}B$,最后一步用到了 $\dfrac{1}{4\lambda^2}B^TA^{-1}B=1$。
因此只需求出 $B^TA^{-1}B$ 即可:在模 $998244353$ 意义下用高斯-约当消元求出 $A^{-1}$,再做两次矩阵乘法。
代码:
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
#include<unordered_map>
#include<set>
// Contest-template shorthand macros. In the visible code only `ll` (long long)
// is used; the others look like unused template boilerplate.
#define ui unsigned int
#define ll long long
#define llu unsigned ll
#define ld long double
#define pr make_pair
#define pb push_back
#define lc (cnt<<1)
#define rc (cnt<<1|1)
#define len(x) (t[(x)].r-t[(x)].l+1)
#define tmid ((l+r)>>1)
using namespace std;
// Template constants. Only `mod` and `maxn` are referenced by the visible code;
// the rest (inf, lnf, dnf, eps, pi, hp, maxm, up) appear unused here.
const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
// Modulus for all matrix arithmetic. It is prime, which get() relies on when it
// inverts pivots as x^(mod-2) via mypow (Fermat's little theorem).
const int mod=998244353;
const double eps=1e-8;
const double pi=acos(-1.0);
const int hp=13331;
// Storage bound for node::a. With 1-based indexing, dimensions up to maxn-1 fit.
const int maxn=210;
const int maxm=100100;
const int up=100000;
// Dense matrix over Z/(mod) with 1-based storage: rows 1..n, columns 1..m.
// input() and the row operations keep entries in [0, mod), so results are never
// reported as negative residues (C++ `%` keeps the sign of the dividend).
// NOTE(review): each node embeds a maxn x maxn int array (~170 KB), so the
// by-value temporaries returned by getT() and operator* are stack-heavy.
struct node
{
    int n,m;            // rows / columns currently in use
    int a[maxn][maxn];  // 1-based entries; row/column 0 is unused

    // Reset to the n x n identity matrix (uses n only; m is left unchanged).
    void init(void)
    {
        memset(a,0,sizeof(a));
        for(int i=1;i<=n;i++)
            a[i][i]=1;
    }

    // Read an n x m matrix from stdin in row-major order, reducing each entry
    // into [0, mod) so negative or oversized input values are handled uniformly.
    void input(void)
    {
        for(int i=1;i<=n;i++)
        {
            for(int j=1;j<=m;j++)
            {
                scanf("%d",&a[i][j]);
                a[i][j]=(a[i][j]%mod+mod)%mod;
            }
        }
    }

    // Row operation: swap rows x and y across all m columns.
    // (Iterating m, not n, keeps this correct for non-square matrices.)
    void _swap(int x,int y)
    {
        for(int j=1;j<=m;j++)
            std::swap(a[x][j],a[y][j]);
    }

    // Row operation: row x *= k (mod), result normalised into [0, mod).
    void mul_k(int x,int k)
    {
        for(int j=1;j<=m;j++)
            a[x][j]=((ll)a[x][j]*k%mod+mod)%mod;
    }

    // Row operation: row y += k * row x (mod), result normalised into [0, mod).
    // k may be negative; the double-mod brings the result back into range.
    void mul_k_add(int x,int k,int y)
    {
        for(int j=1;j<=m;j++)
            a[y][j]=((a[y][j]+(ll)a[x][j]*k)%mod+mod)%mod;
    }

    // Debug helper: print the n x m matrix to stdout, one row per line.
    void print(void)
    {
        for(int i=1;i<=n;i++)
        {
            for(int j=1;j<=m;j++)
                printf("%d ",a[i][j]);
            putchar('\n');
        }
    }

    // Return the transpose (an m x n matrix). Only entries [1..m][1..n] of the
    // result are written; the rest of its storage is left uninitialised.
    node getT(void)
    {
        node ans;
        ans.n=m,ans.m=n;
        for(int i=1;i<=m;i++)
        {
            for(int j=1;j<=n;j++)
                ans.a[i][j]=a[j][i];
        }
        return ans;
    }

    // Matrix product (n x m) * (b.n x b.m) mod `mod`; assumes m == b.n.
    // With entries in [0, mod) each term is < mod^2 (~1e18), which fits in long long.
    node operator * (const node &b) const
    {
        node ans;
        memset(ans.a,0,sizeof(ans.a));
        ans.n=n,ans.m=b.m;
        for(int i=1;i<=n;i++)
        {
            for(int j=1;j<=b.m;j++)
            {
                for(int k=1;k<=m;k++)
                    ans.a[i][j]=(ans.a[i][j]+1ll*a[i][k]*b.a[k][j])%mod;
            }
        }
        return ans;
    }
}a,inva,b,bt,ans;  // a: matrix A, inva: A^{-1}, b: vector b (n x 1); bt/ans unused in visible code
int mypow(int a,int b)
{
int ans=1;
while(b)
{
if(b&1) ans=(ll)ans*a%mod;
a=(ll)a*a%mod;
b>>=1;
}
return ans;
}
// Invert the n x n matrix `a` (n = a.n) modulo `mod` by Gauss-Jordan elimination,
// writing the inverse into `b`. Every row operation is applied to both `a` and
// `b`, so once `a` has been reduced to the identity, `b` holds A^{-1}.
//
// Side effect: `a` is overwritten (left as the identity when A is invertible).
// Assumes A is invertible mod `mod` (the problem guarantees det(A) != 0); there is
// no singularity check, so a missing pivot silently zeroes that row.
void get(node &a,node &b)
{
    const int dim=a.n;
    b.n=b.m=dim;
    b.init();   // b starts as the identity

    // Forward pass: make each pivot 1 and clear the entries below it.
    for(int col=1;col<=dim;col++)
    {
        // Bring the first row with a non-zero entry in this column into pivot position.
        if(a.a[col][col]==0)
        {
            int row=col+1;
            while(row<=dim&&a.a[row][col]==0)
                row++;
            if(row<=dim)
            {
                a._swap(col,row);
                b._swap(col,row);
            }
        }
        // Scale the pivot row so the pivot becomes 1 (inverse via Fermat).
        const int inv=mypow(a.a[col][col],mod-2);
        b.mul_k(col,inv);
        a.mul_k(col,inv);
        // Eliminate this column from every row below the pivot.
        for(int row=col+1;row<=dim;row++)
        {
            const int factor=-a.a[row][col];
            b.mul_k_add(col,factor,row);
            a.mul_k_add(col,factor,row);
        }
    }
    // Backward pass: clear the entries above each pivot, last column first.
    for(int col=dim;col>=1;col--)
    {
        for(int row=col-1;row>=1;row--)
        {
            const int factor=-a.a[row][col];
            b.mul_k_add(col,factor,row);
            a.mul_k_add(col,factor,row);
        }
    }
}
int main(void)
{
int n;
while(scanf("%d",&n)!=EOF)
{
a.n=n,a.m=n;
a.input();
b.n=n,b.m=1;
b.input();
get(a,inva);
printf("%d\n",(b.getT()*inva*b).a[1][1]);
}
return 0;
}