题意: n × m n\times m n×m 的网格图,每个点有两个权值 v a l i , j , b u f i , j val_{i,j},buf_{i,j} vali,j,bufi,j,从 ( 1 , 1 ) (1,1) (1,1) 开始只能向下或向右走到 ( n , m ) (n,m) (n,m) ,在某个位置时可以选择触发该位置的事件(也可不触发),将获得 v a l i , j val_{i,j} vali,j 的得分和 b u f i , j buf_{i,j} bufi,j 的 buff,该 buff 会在之后每一步产生对应的分数直到触发下一个事件。求最大分数。
n m ≤ 1 0 5 nm\leq 10^5 nm≤105,保证 v a l 1 , 1 = v a l n , m = b u f 1 , 1 = b u f n , m = 0 val_{1,1}=val_{n,m}=buf_{1,1}=buf_{n,m}=0 val1,1=valn,m=buf1,1=bufn,m=0
显然可以 dp,设 f ( x , y ) f(x,y) f(x,y) 为走到 ( x , y ) (x,y) (x,y) 并触发事件的最大分数。
转移枚举上一个事件
f ( x , y ) = max i ≤ x , j ≤ y { f ( i , j ) + b u f ( i , j ) ( x + y − i − j ) } + v a l ( i , j ) f(x,y)=\max_{i\leq x,j\leq y}\{f(i,j)+buf(i,j)(x+y-i-j)\}+val(i,j) f(x,y)=i≤x,j≤ymax{f(i,j)+buf(i,j)(x+y−i−j)}+val(i,j)
显然可以斜率优化
这个状态有两维,发现枚举的位置是满足二维偏序,所以可以 cdq 分治搞掉一维,另一维上李超线段树即可。
注意 0 0 0 和 − ∞ -\infin −∞ 的不同。
复杂度 O ( n m log 2 n ) O(nm\log ^2n) O(nmlog2n)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <vector>
using namespace std;
inline int read()
{
int ans=0,f=1;
char c=getchar();
while (!isdigit(c)) (c=='-')&&(f=-1),c=getchar();
while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
return f*ans;
}
const int N=2e5,MAXN=N+5,INF=2e9;
int n,m;
vector<int> val[MAXN],buf[MAXN],f[MAXN];
int k[MAXN],b[MAXN],tot;
inline int calc(const int& i,const int& x){return k[i]*x+b[i];}
int ch[MAXN][2],mx[MAXN],rt,cnt;
inline void clear(){tot=0,rt=cnt=1,ch[1][0]=ch[1][1]=mx[1]=0;}
void modify(int& p,int l,int r,int v)
{
if (!p) p=++cnt,ch[p][0]=ch[p][1]=mx[p]=0;
if (!mx[p]) return (void)(mx[p]=v);
if (calc(mx[p],l)>=calc(v,l)&&calc(mx[p],r)>=calc(v,r)) return;
if (calc(mx[p],l)<=calc(v,l)&&calc(mx[p],r)<=calc(v,r)) return (void)(mx[p]=v);
int mid=(l+r)>>1;
if (calc(v,mid)>calc(mx[p],mid)) swap(v,mx[p]);
if (calc(v,l)>calc(mx[p],l)) modify(ch[p][0],l,mid,v);
if (calc(v,r)>calc(mx[p],r)) modify(ch[p][1],mid+1,r,v);
}
int query(int p,int l,int r,int k)
{
if (!p) return -INF;
int ans=(mx[p]? calc(mx[p],k):-INF);
int mid=(l+r)>>1;
if (k<=mid) ans=max(ans,query(ch[p][0],l,mid,k));
else ans=max(ans,query(ch[p][1],mid+1,r,k));
return ans;
}
void solve(int l,int r)
{
if (l==r)
{
clear();
f[l][1]+=val[l][1];
for (int i=1;i<=m;i++)
{
if (i>1) f[l][i]=max(f[l][i],query(rt,1,N,i))+val[l][i];
++tot,k[tot]=buf[l][i],b[tot]=f[l][i]-i*k[tot];
modify(rt,1,N,tot);
}
return;
}
int mid=(l+r)>>1;
solve(l,mid);
clear();
for (int j=1;j<=m;j++)
{
for (int i=l;i<=mid;i++)
{
++tot,k[tot]=buf[i][j],b[tot]=f[i][j]-(i+j)*k[tot];
modify(rt,1,N,tot);
}
for (int i=mid+1;i<=r;i++) f[i][j]=max(f[i][j],query(rt,1,N,i+j));
}
solve(mid+1,r);
}
int main()
{
n=read(),m=read();
for (int i=1;i<=n;i++)
{
buf[i].resize(m+1);
for (int j=1;j<=m;j++) buf[i][j]=read();
}
for (int i=1;i<=n;i++)
{
val[i].resize(m+1),f[i].resize(m+1,-INF);
for (int j=1;j<=m;j++) val[i][j]=read();
}
f[1][1]=0;
solve(1,n);
printf("%d\n",f[n][m]);
return 0;
}