思路:还是先列出普通的dp方程,dp[i][j]=min(dp[k][j-1]+s(i,k+1)) 表示到前i个站点,用了j次爆炸,得到的最小值,其中s(i,k+1)表示从k+1到i的连乘积。
直接做必然是n^3 因为需要枚举k。考虑到dp[i][j],当j为定值,它随着i单调递增的,所以考虑使用斜率优化来做。
在分析之前先处理一下任意区间的连乘积问题:
有两种解决方案: 1. 记录前缀和和 前缀平方和 (具体原因分析前几项就行)
2. 记录前缀和 和前缀区间乘积 (意思是每新加入一个元素,都会增加 a[i]*(sum[i-1])项,sum[i-1]代表前缀和)
下面就是一般的套路:
设y<k<i (还是那句话,考虑单调性不考虑定义域就是在刷流氓)
假定:dp[k][j-1] +c[i]- c[k] -sum[k]*(sum[i]-sum[k]) 优于 dp[y][j-1] +c[i] -c[y] -sum[y]*(sum[i] - sum[y])
其中 c[i]代表解决方案二中的 前缀区间 乘积。
那么必然满足:dp[k][j-1] +c[i]- c[k] -sum[k]*(sum[i]-sum[k]) < dp[y][j-1] +c[i] -c[y] -sum[y]*(sum[i] - sum[y]) (因为我们求的是最小值)
化简得:(dp[k][j-1] - c[k] +sum[k]*sum[k] -( dp[y][j-1] -c[y] +sum[y]*sum[y] ) ) / (sum[k] -sum[y]) < sum[i]
观察到sum[i]是单调递增的,而前面的式子可以看做是平面上两点构成的斜率,也就是说我们的斜率在不断上升,而我们需要求出最小值。所以取到最值的点必然是下凸的。(这是从坐标图上来分析的)
从数值的角度分析:当我们已经知道 k比 y优的时候,y就已经完全没有用处了,同时,当我们考虑将新点i加入时,也要考虑维护一个下凸值。也就是说 如果 y<k<i 如何 ky 的斜率大于 ik的斜率,那么k点必然是没有用的,具体分析可以参考前面的一篇博客。
下面就来分析这题的具体实现:首先这是一个二维的dp,我们对不同的j需要维护不同的单调队列 ,因为我们之前的分析,都是建立在j不变的情况下讨论的。
下面的代码是采用记录前缀区间乘积实现的(记录前缀平方和大概是乘0.5导致精度问题,所以wa了,思想是没问题的)
代码1:前缀区间乘积(已AC)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll inf=1e18;
const int maxn=1e3+7;
const int mod=1e9+7;
ll sum[maxn],sp[maxn];
ll dp[maxn][maxn];
int a[maxn];
int q[maxn];
int n,m;
ll solve(int k,int y,int j) //y<k,且k优于y
{
if(sum[k]==sum[y])
{
if(dp[k][j-1]-sp[k]<dp[y][j-1]-sp[y])
{
return -1;
}
else
{
return inf;
}
}
return ((dp[k][j-1]-sp[k]+sum[k]*sum[k])-(dp[y][j-1]-sp[y]+sum[y]*sum[y]))/
(sum[k]-sum[y]);
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("in.txt","r",stdin);
freopen("out.txt","w",stdout);
#endif
while(scanf("%d%d",&n,&m)!=EOF&&(n+m))
{
sum[0]=sp[0]=0;
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
sum[i]=sum[i-1]+a[i];
sp[i]=sp[i-1]+a[i]*sum[i-1];
}
int head,tail;
// for(int j=1;j<=m;j++)
// {
// dp[1][j]=a[1];
// }
for(int i=1;i<=n;i++)
{
dp[i][0]=sp[i];
}
for(int j=1;j<=m;j++)
{
head=tail=0;
q[0]=j;
for(int i=j+1;i<=n;i++)
{
while(head<tail&&solve(q[head+1],q[head],j)<sum[i])
{
head++;
}
dp[i][j]=dp[q[head]][j-1]+sp[i]-sp[q[head]]-sum[q[head]]*(sum[i]-sum[q[head]]);
while(head<tail&&solve(q[tail],q[tail-1],j)>solve(i,q[tail],j))
{
tail--;
}
q[++tail]=i;
}
}
printf("%lld\n",dp[n][m]);
}
return 0;
}
代码2:前缀平方和(WA了)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll inf=1e18;
const int maxn=1e3+7;
const int mod=1e9+7;
ll sum[maxn],sp[maxn];
ll dp[maxn][maxn];
int a[maxn];
int q[maxn];
int n,m;
ll solve(int k,int y,int j) //y<k,且k优于y
{
if(sum[k]==sum[y])
{
if(dp[k][j-1]+0.5*sp[k]<dp[y][j-1]+0.5*sp[y])
{
return -1;
}
else
{
return inf;
}
}
return (dp[k][j-1]+(sum[k]*sum[k]+sp[k])/2-(dp[y][j-1]+(sum[y]*sum[y]+sp[y])/2))/
(sum[k]-sum[y]);
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("in.txt","r",stdin);
freopen("out.txt","w",stdout);
#endif
while(scanf("%d%d",&n,&m)!=EOF&&(n+m))
{
sum[0]=sp[0]=0;
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
sum[i]=sum[i-1]+a[i];
sp[i]=a[i]*a[i]+sp[i-1];
}
int head,tail;
// for(int j=1;j<=m;j++)
// {
// dp[1][j]=a[1];
// }
for(int i=1;i<=n;i++)
{
dp[i][0]=sum[i]*sum[i]-sp[i];
}
for(int j=1;j<=m;j++)
{
head=tail=0;
q[0]=j;
for(int i=j+1;i<=n;i++)
{
while(head<tail&&solve(q[head+1],q[head],j)<sum[i])
{
head++;
}
dp[i][j]=dp[q[head]][j-1]+((sum[i]-sum[q[head]])*(sum[i]-sum[q[head]])-(sp[i]-sp[q[head]]))/2;
while(head<tail&&solve(q[tail],q[tail-1],j)>solve(i,q[tail],j))
{
tail--;
}
q[++tail]=i;
}
}
printf("%lld\n",dp[n][m]);
}
return 0;
}