QWQ菜的真实。
首先来看这个题。
很显然能得到一个朴素的
d
p
dp
dp柿子
d p [ i ] = m a x ( d p [ i ] , d p [ j ] + ( s u m [ i ] − s u m [ j ] ) 2 ) dp[i]=max(dp[i],dp[j]+(sum[i]-sum[j])^2) dp[i]=max(dp[i],dp[j]+(sum[i]−sum[j])2)
但是因为
n
≤
500000
n\le 500000
n≤500000,所以
n
2
n^2
n2一定是过不了的。
考虑应该怎么优化。
考虑什么时候存在一个 j > k 且 j 比 k 更 优 秀 j>k且j比k更优秀 j>k且j比k更优秀
d p [ j ] + ( s u m [ i ] − s u m [ j ] ) 2 < d p [ k ] + ( s u m [ i ] − s u m [ k ] ) 2 dp[j]+(sum[i]-sum[j])^2<dp[k]+(sum[i]-sum[k])^2 dp[j]+(sum[i]−sum[j])2<dp[k]+(sum[i]−sum[k])2
我们进行化简
2
×
s
[
i
]
×
(
s
[
j
]
−
s
[
k
]
)
>
d
p
[
j
]
+
s
u
m
[
j
]
2
−
d
p
[
k
]
−
s
u
m
[
k
]
2
2\times s[i] \times (s[j]-s[k]) > dp[j]+sum[j]^2-dp[k]-sum[k]^2
2×s[i]×(s[j]−s[k])>dp[j]+sum[j]2−dp[k]−sum[k]2
由于权值都是正数,所以
s
[
j
]
−
s
[
k
]
>
0
s[j]-s[k]>0
s[j]−s[k]>0
我们设
f
[
x
]
=
s
u
m
[
x
]
2
+
d
p
[
x
]
f[x]=sum[x]^2+dp[x]
f[x]=sum[x]2+dp[x]
则上述柿子等于
2
×
s
[
i
]
>
f
[
j
]
−
f
[
k
]
s
[
j
]
−
s
[
k
]
2\times s[i]>\frac{f[j]-f[k]}{s[j]-s[k]}
2×s[i]>s[j]−s[k]f[j]−f[k]
观察到右边这个柿子是一个斜率的形式。
我们可以直接用单调队列维护一个下凸壳。
对于每次插入一个点,运用叉积进行 c h e c k check check,保证斜率是单调不降的。
int chacheng(Point x,Point y)
{
return x.x*y.y-y.x*x.y;
}
bool count(Point i,Point j,Point k)
{
Point x,y;
x.x=(k.x-i.x);
x.y=(k.y-i.y);
y.x=(k.x-j.x);
y.y=(k.y-j.y);
if (chacheng(x,y)<=0) return true;
return false;
// if ((double)(k.y-j.y)/(double)(k.x-j.x)<(double)(j.y-i.y)/(double)(j.x-i.x)) return true;
//return false;
}
void push(Point x)
{
while (tail>=head+1 && count(q[tail-1],q[tail],x)) tail--;
q[++tail]=x;
}
删除的话,只需要通过上面那个柿子,若存在 q [ h e a d + 1 ] 比 q [ h e a d ] q[head+1]比q[head] q[head+1]比q[head]优秀,就弹出队首元素
void pop(int lim)
{
while (tail>=head+1 && (q[head+1].y-q[head].y)<=lim*(q[head+1].x-q[head].x)) head++;
}
剩下的就是 d p dp dp部分
qwq因为一些奇奇怪怪的问题
W
A
WA
WA了一上午
xtbl
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<map>
#include<set>
#define mk make_pair
#define ll long long
#define int long long
using namespace std;
inline int read()
{
int x=0,f=1;char ch=getchar();
while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)) {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
return x*f;
}
const int maxn = 1e6+1e2;
struct Point{
int x,y;
};
Point q[maxn];
int dp[maxn];
int sum[maxn];
int val[maxn];
int n,m;
int head=1,tail=0;
int chacheng(Point x,Point y)
{
return x.x*y.y-y.x*x.y;
}
bool count(Point i,Point j,Point k)
{
Point x,y;
x.x=(k.x-i.x);
x.y=(k.y-i.y);
y.x=(k.x-j.x);
y.y=(k.y-j.y);
if (chacheng(x,y)<=0) return true;
return false;
// if ((double)(k.y-j.y)/(double)(k.x-j.x)<(double)(j.y-i.y)/(double)(j.x-i.x)) return true;
//return false;
}
void push(Point x)
{
while (tail>=head+1 && count(q[tail-1],q[tail],x)) tail--;
q[++tail]=x;
}
void pop(int lim)
{
while (tail>=head+1 && (q[head+1].y-q[head].y)<=lim*(q[head+1].x-q[head].x)) head++;
}
signed main()
{
while (scanf("%lld%lld",&n,&m)!=EOF)
{
memset(q,0,sizeof(q));
memset(dp,0,sizeof(dp));
memset(sum,0,sizeof(sum));
head=1,tail=0;
//n=read();m=read();
for (int i=1;i<=n;i++) val[i]=read();
for (int i=1;i<=n;i++) sum[i]=sum[i-1]+val[i];
dp[0]=0;
push((Point){0,0});
for (int i=1;i<=n;i++)
{
pop(2ll*sum[i]);
dp[i]=q[head].y-q[head].x*q[head].x+m+(sum[i]-q[head].x)*(sum[i]-q[head].x);
push((Point){sum[i],dp[i]+sum[i]*sum[i]});
//cout<<i<<" "<<dp[i]<<" "<<q[head].x<<" "<<q[head].y<<" "<<head<<" "<<tail<<endl;
}
cout<<dp[n]<<"\n";
}
return 0;
}