思路:
斜率优化。
设 f[i] 表示将前 i 个元素分组后能得到的最大值,枚举最后一组为 (j, i](其和为 s[i]-s[j]),则有转移方程:
f[i]=max{ f[j]+a*(s[i]-s[j])^2+b*(s[i]-s[j])+c }
经过化简得到:f[i]=max{ (f[j]+a*s[j]^2-b*s[j])-2*a*s[i]*s[j] } + a*s[i]^2+b*s[i]+c
单调队列维护上凸包即可。
y[j] = (f[j]+a*s[j]^2-b*s[j])
x[j] = s[j]
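max p = y[j]-2*a*s[i]*x[j]:相当于用斜率 k = 2*a*s[i] 的直线去切点集 {(x[j], y[j])},使截距 p 最大,所以维护上凸包。因为 a 为负,且 s[i] 单调递增(要求输入值非负),k 单调递减,最优决策点只会向右移动,所以可以用单调队列从队首弹出。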
now.x = x[i] = s[i]
now.y = y[i] = (f[i]+a*s[i]^2-b*s[i]) = {(f[j]+a*s[j]^2-b*s[j])-2*a*s[i]*s[j]+ a*s[i]^2+b*s[i]+c} + a*s[i]^2-b*s[i] = {(q[L].y)-2*a*s[i]*q[L].x} + 2*a*s[i]^2+c
答案就是 f[n] = y[n]-a*s[n]^2+b*s[n]。循环结束时 q[R] 正是点 n,所以输出 q[R].y-a*s[n]*s[n]+b*s[n]
代码一:
#include <bits/stdc++.h>
using namespace std;
// Shared contest prelude: 64-bit integer alias, fill/pair shorthands and
// "infinity" sentinels (0x3f bytes, so memset-friendly and safe to add twice).
using ll = long long;
#define mem(arr) memset(arr,0,sizeof(arr))
#define mp(lhs,rhs) make_pair(lhs,rhs)
constexpr int INF = 0x3f3f3f3f;
constexpr ll INFLL = 0x3f3f3f3f3f3f3f3fLL;
// Reads the next signed decimal integer from stdin with getchar().
// Non-digit characters before the number are skipped; any '-' among them
// makes the result negative. Reading stops at the first non-digit after the
// number (that character is consumed).
// Returns 0 if end of input is reached before any digit is seen.
inline long long read(){
    long long x = 0, f = 1;
    // `int`, not `char`: getchar() signals end of input with EOF, which must
    // stay distinguishable from every real character. Without the EOF check
    // the skip loop below would spin forever once the input is exhausted.
    int ch = getchar();
    while(ch != EOF && (ch < '0' || ch > '9')){
        if(ch == '-') f = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9'){
        x = x*10 + (ch - '0');
        ch = getchar();
    }
    return x*f;
}
// Upper bound on n (plus slack) used to size the global arrays.
const int maxn = 1000000 + 10;

// A point of the slope-optimisation hull for index j:
// x = s[j], y = f[j] + a*s[j]^2 - b*s[j].
struct node {
    ll x;
    ll y;
};
node now;        // the point being inserted for the current i
node q[maxn];    // monotone queue holding the upper convex hull

int n, x;        // n: number of values; x: scratch for the value just read
ll a, b, c;      // coefficients of the per-group score a*t^2 + b*t + c
ll s[maxn];      // prefix sums of the input values, s[0] = 0
// Cross product (B - A) x (C - A).
// Positive when C lies to the left of the directed line A -> B, negative when
// it lies to the right, zero when A, B and C are collinear.
// NOTE(review): the two products are plain long long; very large coordinate
// differences could overflow. Presumably fine for the intended limits; confirm.
ll cross(node A,node B,node C){
    const ll abx = B.x - A.x;
    const ll aby = B.y - A.y;
    const ll acx = C.x - A.x;
    const ll acy = C.y - A.y;
    return abx*acy - acx*aby;
}
// Reads n, the coefficients a, b, c and the n values, then runs the
// slope-optimised DP and prints f[n].
// Each index j is the hull point (x, y) = (s[j], f[j] + a*s[j]^2 - b*s[j]);
// q[L..R] is a monotone queue holding the upper convex hull of those points.
// Assumes the input values are non-negative, so s[] is non-decreasing and the
// best point only moves right; the notes above also assume a < 0.
int main(){
    n=read();
    a=read(),b=read(),c=read();
    s[0] = 0;
    for(int i=1; i<=n; i++){
        x=read();
        s[i] = s[i-1]+x;
    }
    // q[0] is the zero-initialised point for j = 0 (f[0] = 0, s[0] = 0).
    int L=0,R=0;
    for(int i=1; i<=n; i++){
        // Drop front points whose successor is at least as good for this i,
        // i.e. maximises y - 2*a*s[i]*x no worse.
        while(L<R && q[L+1].y-2*a*s[i]*q[L+1].x >= q[L].y-2*a*s[i]*q[L].x) L++;
        now.x = s[i];
        // now.y = y[i] = f[i] + a*s[i]^2 - b*s[i], where f[i] comes from the
        // best point q[L]; the b terms cancel, leaving 2*a*s[i]^2 + c.
        now.y = q[L].y-2*a*s[i]*q[L].x+2*a*s[i]*s[i]+c;
        // cross(q[R-1], now, q[R]) > 0 exactly when q[R] lies to the left of
        // (above) the segment q[R-1] -> now. Otherwise q[R] cannot be on the
        // upper hull once `now` is added, so pop it.
        while(L<R && cross(q[R-1],now,q[R]) <= 0) R--;
        q[++R] = now;
    }
    // The last point pushed is the one for i = n, so f[n] = y[n] - a*s[n]^2 + b*s[n].
    cout << q[R].y-a*s[n]*s[n]+b*s[n] << '\n';
    return 0;
}
代码二:
#include <bits/stdc++.h>
using namespace std;
// Shared contest prelude: 64-bit integer alias, fill/pair shorthands and
// "infinity" sentinels (0x3f bytes, so memset-friendly and safe to add twice).
using ll = long long;
#define mem(arr) memset(arr,0,sizeof(arr))
#define mp(lhs,rhs) make_pair(lhs,rhs)
constexpr int INF = 0x3f3f3f3f;
constexpr ll INFLL = 0x3f3f3f3f3f3f3f3fLL;
// Reads the next signed decimal integer from stdin with getchar().
// Non-digit characters before the number are skipped; any '-' among them
// makes the result negative. Reading stops at the first non-digit after the
// number (that character is consumed).
// Returns 0 if end of input is reached before any digit is seen.
inline long long read(){
    long long x = 0, f = 1;
    // `int`, not `char`: getchar() signals end of input with EOF, which must
    // stay distinguishable from every real character. Without the EOF check
    // the skip loop below would spin forever once the input is exhausted.
    int ch = getchar();
    while(ch != EOF && (ch < '0' || ch > '9')){
        if(ch == '-') f = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9'){
        x = x*10 + (ch - '0');
        ch = getchar();
    }
    return x*f;
}
// Upper bound on n (plus slack) used to size the global arrays.
const int maxn = 1000000 + 10;

int n, x;        // n: number of values; x: scratch for the value just read
int q[maxn];     // monotone queue of candidate indices j (upper convex hull)
ll a, b, c;      // coefficients of the per-group score a*t^2 + b*t + c
ll s[maxn];      // prefix sums of the input values, s[0] = 0
ll f[maxn];      // f[i]: best total for the first i values, f[0] = 0
// Numerator of the slope between hull points j and k: y[j] - y[k],
// where y[t] = f[t] + a*s[t]^2 - b*s[t].
ll getup(int j,int k){
    const ll df  = f[j] - f[k];
    const ll dsq = s[j]*s[j] - s[k]*s[k];
    const ll ds  = s[k] - s[j];
    return df + a*dsq + b*ds;
}
// Denominator of the slope between hull points j and k: x[j] - x[k] = s[j] - s[k].
ll getdown(int j,int k){
    const ll dx = s[j] - s[k];
    return dx;
}
int main(){
n=read();
a=read(),b=read(),c=read();
s[0] = 0;
for(int i=1; i<=n; i++){
x=read();
s[i] = s[i-1]+x;
}
int L=0,R=0;
for(int i=1; i<=n; i++){
while(L<R && getup(q[L+1],q[L]) >= s[i]*2*a*getdown(q[L+1],q[L])) L++; // a是负的
int j = q[L];
f[i] = f[j] + a*(s[i]-s[j])*(s[i]-s[j]) + b*(s[i]-s[j]) + c;
while(L<R && getup(i,q[R])*getdown(q[R],q[R-1]) >= getup(q[R],q[R-1])*getdown(i,q[R])) R--;
q[++R] = i;
}
cout << f[n] << endl;
return 0;
}