传送门:http://codeforces.com/problemset/problem/573/d
思路:首先如果没有限制,那么根据排序不等式,肯定按顺序匹配战士和马最好。
但是现在有了战士不能和自己的马匹配的限制。
于是就有了一个重要的性质:
最优匹配的前提下,排序后第i号战士只会与[i-2,i+2]号马匹配
至于证明,可以自己YY,也可以分情况讨论(好像很复杂...)
于是就可以DP了,设f[i]表示1-i号战士正好和1-i号战马匹配
记ban[i]表示第i号战士不能匹配的战马
那么转移方程就是:
f[i]=max{
f[i-1]+a[i]*b[i];(ban[i]!=i)
f[i-2]+a[i]*b[i-1]+a[i-1]*b[i];(ban[i]!=i-1)
f[i-3]+a[i]*b[i-2]+a[i-1]*b[i]+a[i-2]*b[i-1](ban[i]!=i-2&&ban[i-1]!=i&&ban[i-2]!=i-1)
f[i-3]+a[i]*b[i-1]+a[i-1]*b[i-2]+a[i-2]*b[i](ban[i]!=i-1&&ban[i-1]!=i-2&&ban[i-2]!=i)
}
然后每次DP一遍,复杂度O(qn),加一些常数优化就可以过了。
#include<cstdio>
#include<cstring>
#include<algorithm>
typedef long long ll;
const ll inf=1ll<<60;
const int maxn=30010;
using namespace std;
struct node{
ll v;int id;//价值,不能匹配的编号
}a[maxn],b[maxn];
int n,q,posa[maxn],posb[maxn],ban[maxn];ll f[maxn],w1[maxn],w2[maxn],w3[maxn];
bool cmp(node a,node b){return a.v<b.v;}
void calc(int i){
w1[i]=w2[i]=w3[i]=-inf;
if (i>=1&&ban[i]!=i) w1[i]=a[i].v*b[i].v;
if (i>=2&&ban[i]!=i-1&&ban[i-1]!=i) w2[i]=a[i].v*b[i-1].v+a[i-1].v*b[i].v;
if (i>=3){
if (ban[i]!=i-2&&ban[i-1]!=i&&ban[i-2]!=i-1)
w3[i]=a[i].v*b[i-2].v+a[i-1].v*b[i].v+a[i-2].v*b[i-1].v;
if (ban[i]!=i-1&&ban[i-1]!=i-2&&ban[i-2]!=i)
w3[i]=max(w3[i],a[i].v*b[i-1].v+a[i-1].v*b[i-2].v+a[i-2].v*b[i].v);
}
}
int main(){
scanf("%d%d",&n,&q);
for (int i=1;i<=n;i++) scanf("%I64d",&a[i].v),a[i].id=i;
for (int i=1;i<=n;i++) scanf("%I64d",&b[i].v),b[i].id=i;
sort(a+1,a+1+n,cmp),sort(b+1,b+1+n,cmp);
for (int i=1;i<=n;i++) posa[a[i].id]=i,posb[b[i].id]=i;
for (int i=1;i<=n;i++) ban[i]=posb[a[i].id];
for (int i=1;i<=n;i++) calc(i);
// for (int i=1;i<=n;i++) printf("%d\n",ban[i]);
for (int j=1,x,y;j<=q;j++){
scanf("%d%d",&x,&y);
x=posa[x],y=posa[y],swap(ban[x],ban[y]);
for (int i=max(1,x-5);i<=min(n,x+5);i++) calc(i);
for (int i=max(1,y-5);i<=min(n,y+5);i++) calc(i);
f[0]=0;
for (int i=1;i<=n;i++){
if (i>=1) f[i]=f[i-1]+w1[i];
if (i>=2) f[i]=max(f[i],f[i-2]+w2[i]);
if (i>=3) f[i]=max(f[i],f[i-3]+w3[i]);
}
printf("%I64d\n",f[n]);
}
return 0;
}
/*
4 15
70 46 78 69
90 93 83 11
2 3
3 4
4 1
3 1
4 3
3 1
2 4
3 1
3 2
3 4
2 3
1 2
1 4
4 1
1 2
*/
然而还有更优的写法,用矩阵乘法+线段树优化
观察转移方程
f[i]=max{
f[i-1]+a[i]*b[i]; (ban[i]!=i)
f[i-2]+a[i]*b[i-1]+a[i-1]*b[i]; (ban[i]!=i-1)
f[i-3]+a[i]*b[i-2]+a[i-1]*b[i]+a[i-2]*b[i-1]; (ban[i]!=i-2&&ban[i-1]!=i&&ban[i-2]!=i-1)
f[i-3]+a[i]*b[i-1]+a[i-1]*b[i-2]+a[i-2]*b[i]; (ban[i]!=i-1&&ban[i-1]!=i-2&&ban[i-2]!=i)
}
可以发现f[i]只与f[i-1],f[i-2],f[i-3]有关我们于是联想到可以用一个3*3的矩阵来转移
而每次修改又只会影响几个转移矩阵
于是可以用线段树维护区间矩阵乘法,修改时就暴力改动那几个矩阵
但是这里的矩阵乘法和平时的乘法有些不同
现在乘法的定义是:c=a*b
c[i][j]=max(a[i][k]+b[k][j])
而正常的定义是
c[i][j]=Σ(a[i][k]*b[k][j])
我们现在只要证它有结合律就可以用线段树来维护区间乘法
令F=(a*b)*c,G=a*(b*c)
那么就有
F[i][j]=max(max(a[i][u]+b[u][v])+c[v][j])
G[i][j]=max(a[i][u]+max(b[u][v]+c[v][j]))
因为max和+有结合律,+对max有分配律
max((max(a,b)),c)=max(a,max(b,c))
(a+b)+c=a+(b+c)
max(a,b)+c=max(a+c,b+c)
于是G[i][j]=max(max(a[i][u]+b[u][v])+c[v][j])=F[i][j]
那么我们就证明了这种矩阵乘法有结合律
于是就可以上线段树解决了
具体细节:
如果不能转移就填-1
转移矩阵就是
-1 -1 case(i-2)
0 -1 case(i-1)
-1 0 case(i)
-1 -1 -1 *-1 -1 case(i-2)=-1 -1 -1
-1 -1 -1 0 -1 case(i-1) -1 -1 -1
f[i-2] f[i-1] f[i] -1 0 case(i) f[i-1] f[i] f[i-2]
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ls (p<<1)
#define rs ((p<<1)|1)
#define mid ((l+r)>>1)
const int maxn=30010,maxt=maxn<<2;
typedef long long ll;
using namespace std;
int n,q,posa[maxn],posb[maxn],ban[maxn];
struct node{ll v;int id;}a[maxn],b[maxn];
bool cmp(node a,node b){return a.v<b.v;}
struct matrix{
ll mat[3][3];
void clear(){memset(mat,-1,sizeof(mat));}
};
matrix operator *(matrix a,matrix b){
matrix res;res.clear();
for (int i=0;i<3;i++)
for (int k=0;k<3;k++) if (a.mat[i][k]!=-1)
for (int j=0;j<3;j++) if (b.mat[k][j]!=-1)
res.mat[i][j]=max(res.mat[i][j],a.mat[i][k]+b.mat[k][j]);
return res;
}
matrix get_matrix(int i){
matrix c;c.clear();
c.mat[1][0]=c.mat[2][1]=0;
if (ban[i]!=i) c.mat[2][2]=a[i].v*b[i].v;
if (i<=1) return c;
if (ban[i]!=i-1) c.mat[1][2]=a[i].v*b[i-1].v+a[i-1].v*b[i].v;
if (i<=2) return c;
ll v1=-1,v2=-1;
if (ban[i]!=i-1&&ban[i-1]!=i-2&&ban[i-2]!=i) v1=a[i].v*b[i-1].v+a[i-1].v*b[i-2].v+a[i-2].v*b[i].v;
if (ban[i]!=i-2&&ban[i-1]!=i&&ban[i-2]!=i-1) v2=a[i].v*b[i-2].v+a[i-1].v*b[i].v+a[i-2].v*b[i-1].v;
c.mat[0][2]=max(v1,v2);
return c;
}
struct Segment_Tree{
matrix t[maxt];
void build(int p,int l,int r){
//printf("%d %d %d\n",p,l,r);
if (l==r){t[p]=get_matrix(l);return;}
build(ls,l,mid),build(rs,mid+1,r);
t[p]=t[ls]*t[rs];
}
void modify(int p,int l,int r,int a){
if (l==r){t[p]=get_matrix(l);return;}
if (a<=mid) modify(ls,l,mid,a);
else modify(rs,mid+1,r,a);
t[p]=t[ls]*t[rs];
}
void modify(int a){modify(1,1,n,a);}
ll query(){return t[1].mat[2][2];}
}T;
int main(){
scanf("%d%d",&n,&q);
for (int i=1;i<=n;i++) scanf("%I64d",&a[i].v),a[i].id=i;
for (int i=1;i<=n;i++) scanf("%I64d",&b[i].v),b[i].id=i;
sort(a+1,a+1+n,cmp),sort(b+1,b+1+n,cmp);
for (int i=1;i<=n;i++) posa[a[i].id]=i,posb[b[i].id]=i;
for (int i=1;i<=n;i++) ban[i]=posb[a[i].id];
T.build(1,1,n);
/*for (int i=1;i<=n;i++,puts("")){
printf("%d\n",i);matrix c=get_matrix(i);
for (int j=0;j<3;j++,puts(""))
for (int k=0;k<3;k++) printf("%lld ",c.mat[j][k]);
}*/
for (int x,y;q;q--){
scanf("%d%d",&x,&y),x=posa[x],y=posa[y],swap(ban[x],ban[y]);
for (int i=max(1,x-2);i<=min(x+2,n);i++) T.modify(i);
for (int i=max(1,y-2);i<=min(y+2,n);i++) T.modify(i);
printf("%I64d\n",T.query());
}
return 0;
}