题目链接
题目解法
首先可以得到一个显然的
d
p
dp
dp 方程
d
p
i
=
m
i
n
{
d
p
j
+
(
i
−
j
)
2
+
m
i
n
{
a
i
,
a
i
+
1
,
.
.
.
,
a
j
}
}
dp_i=min\{dp_j+(i-j)^2+min\{a_i,a_{i+1},...,a_j\}\}
dpi=min{dpj+(i−j)2+min{ai,ai+1,...,aj}}
因为
m
i
n
min
min 不好斜率优化,且难以用线段树维护
考虑值域
≤
n
\le n
≤n,所以一个很长的区间有很大可能
m
i
n
min
min 很小,所以考虑在值域上做文章
考虑如果
i
i
i 的最有转移点为
j
j
j,那么
m
i
n
{
a
i
,
.
.
.
,
a
j
}
∗
l
e
n
2
<
n
∗
l
e
n
min\{a_i,...,a_j\}*len^2<n*len
min{ai,...,aj}∗len2<n∗len
因为每个权值最多为
n
n
n,一个一个跳的权值最大为
n
∗
l
e
n
n*len
n∗len
所以
m
i
n
{
a
i
,
.
.
.
,
a
j
}
<
n
l
e
n
min\{a_i,...,a_j\}<\frac{n}{len}
min{ai,...,aj}<lenn
考虑值域上根号分治
- 区间最小值 > n >\sqrt n >n,那么该区间长度最长为 n \sqrt n n,时间复杂度 O ( n n ) O(n\sqrt n) O(nn)
- 区间最小值
≤
n
\le \sqrt n
≤n
考虑对于 j < k < i j<k<i j<k<i,若 j j j 为 i i i 的最优转移点,且 a k = m i n { a i , . . . , a j } a_k=min\{a_i,...,a_j\} ak=min{ai,...,aj}
那么 a k ∗ ( i − k ) 2 + a k ∗ ( k − j ) 2 < a k ∗ ( i − j ) 2 a_k*(i-k)^2+a_k*(k-j)^2<a_k*(i-j)^2 ak∗(i−k)2+ak∗(k−j)2<ak∗(i−j)2
所以 k k k 比 j j j 更优,矛盾
这说明如果 j j j 为 i i i 的最优转移点,那么 a i = m i n { a i , . . . , a j } a_i=min\{a_i,...,a_j\} ai=min{ai,...,aj} 或 a j = m i n { a i , . . . , a j } a_j=min\{a_i,...,a_j\} aj=min{ai,...,aj}
对于 i i i 为 m i n min min 的情况,之间在后面暴力找第一个 a j < = a i a_j<=a_i aj<=ai 即可,考虑对于 1 , . . . , n 1,...,\sqrt n 1,...,n,每个数只会被扫一次 1 − n 1-n 1−n,所以时间复杂度 O ( n n ) O(n\sqrt n) O(nn)
对于 j j j 为 m i n min min 的情况,用栈记录之前所有 ≤ a i \le a_i ≤ai 的数即可,栈中最多 n \sqrt n n 个元素,时间复杂度 O ( n n ) O(n\sqrt n) O(nn)
总的时间复杂度 O ( n n ) O(n\sqrt n) O(nn)
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N(400100);
int n,a[N],dp[N];
int stk[N],top;
inline int read(){
int FF=0,RR=1;
char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') RR=-1;
for(;isdigit(ch);ch=getchar()) FF=(FF<<1)+(FF<<3)+ch-48;
return FF*RR;
}
signed main(){
n=read();
for(int i=1;i<=n;i++) a[i]=read();
memset(dp,0x3f,sizeof(dp));
dp[1]=0,stk[++top]=1;
int B=sqrt(n)+2;
for(int i=2;i<=n;i++){
for(int j=1,mn=a[i];j<=B+5;j++){
if(i-j<1) break;
mn=min(mn,a[i-j]);
dp[i]=min(dp[i],dp[i-j]+j*j*mn);
}
for(int j=1;j<=top;j++) dp[i]=min(dp[i],dp[stk[j]]+a[stk[j]]*(i-stk[j])*(i-stk[j]));
if(a[i]<=B){
while(top&&a[i]<=a[stk[top]]) top--;
stk[++top]=i;
for(int j=i-1;j>=1;j--){
if(a[j]<=a[i]) break;
dp[i]=min(dp[i],dp[j]+a[i]*(i-j)*(i-j));
}
}
}
for(int i=1;i<=n;i++) printf("%lld ",dp[i]);
return 0;
}