I - Necklace
题目大意:给出一个环每一段的价值为区间和的平方,求将环断成k断的最小价值之和。
解题思路:容易想到n^4的dp
首先断环成链,将数组复制一遍。
用dp[i][j]表示前i个分成j断的最小代价,那么我们就枚举区间,枚举段数,没增加一个点,当前点要么独立成一段,要么跟前边的成一段,所以再枚举间断点。
由于起点不确定,所以我们要做n次的dp,时间复杂度为n^4
这个是超时的代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#include<bitset>
#include<ctime>
#include<map>
#include<set>
#include<algorithm>
#include<vector>
using namespace std;
#define LL long long
#define N 5000005
#define maxn 10005
#define inf 0x3f3f3f3f
#define sca(x) scanf("%d",&x)
#define pb(x) push_back(x)
LL a[405],b[405];
LL sum[405];
LL dp[405][405];
LL Sum(int l,int r)
{
if(l>r) return 0;
return (sum[r]-sum[l-1])*(sum[r]-sum[l-1]);
}
LL solve(int n,int k)
{
for(int i=1;i<=n;i++)
{
for(int j=1;j<=k;j++)dp[i][j]=inf;
}
for(int i=1;i<=n;i++)dp[i][0]=0;
for(int i=1; i<=n; i++)
{
for(int j=1; j<=k; j++)
{
if(i<j)continue;
if(j==1)dp[i][j]=Sum(1,i);
else
for(int l=i; l>=j-1; l--)
{
dp[i][j]=min(dp[i][j],dp[l][j-1]+Sum(l+1,i));
}
//cout<<i<<" "<<j<<" "<<dp[i][j]<<endl;
}
}
return dp[n][k];
}
int main()
{
int t;
sca(t);
while(t--)
{
int n,k;
scanf("%d%d",&n,&k);
sum[0]=0;
for(int i=1; i<=n; i++)scanf("%lld",&a[i]),a[n+i]=a[i];
LL ans=inf;
for(int i=1; i<=n; i++)
{
int cnt=1;
sum[0]=0;
for(int j=i; j<=i+n-1; j++)
{
b[cnt]=a[j];
sum[cnt]=sum[cnt-1]+a[j];
cnt++;
}
ans=min(ans,solve(n,k));
}
printf("%lld\n",ans);
}
}
当时只是知道可以优化但是不知道怎么去优化。后来学习一斜率优化。
首先我们知道
前i个分成j段
Sum(l,r)=(sum[r]-sum[l])*(sum[r]-sum[l]);
即dp[i][j]=dp[l][j-1]+Sum(l,r)^2 (l>=0 && r<i)
其中dp[i][j]要取最小值
展开 dp[i][j]=dp[l][j-1]+sum[r]^2+sum[l]^2-2*sum[r]sum[l]
移项 dp[i][j]+2*sum[r]sum[l]=sum[l][j-1]+sum[r]^2+sum[l]^2;
这个式子可以看做是 b+kx=y
因为2sum[r]为常数,可以看做是当前的斜率。sum[l]使我们要枚举的x点 sum[l][j-1]是我们在上层循环已经求过的值
可以看做是y点(sum[r]^2是一个常数,在求斜率的时候相减就抵消了,所以两边都是单变量)。
此时我们要求的dp[i][j]就是x=0时的截距。
我们希望截距尽量小。
所以维护一个下凸包。
如图所示,当我去找使刚才那个式子的截距最小的点对的时候,我们发现
若果当前队列中的前两个点的斜率小于当前的斜率,那么A点就可以不要了。因为在B点可以取得更优的值。
同理我们将队列前边的不符合要求的点都去掉,取得最优值。
然后我们要将现在的点加入,加入的时候也是同理,我们比较当前队列尾部的两个点的斜率和当前斜率(对于本题来说是2*sum[i])的关系,如果斜率大于当前斜率,我们发现C点也可以去掉了。
就这样不断去除两端的点,就得到了优化。
可以优化到n^3。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#include<bitset>
#include<ctime>
#include<map>
#include<set>
#include<algorithm>
#include<vector>
using namespace std;
#define N 5000005
#define maxn 10005
#define inf 0x3f3f3f3f
#define sca(x) scanf("%d",&x)
#define pb(x) push_back(x)
int a[405],b[405];
int sum[405];
int dp[205][205];
int q[1005];
int Sum(int l,int r)
{
if(l>r) return 0;
return (sum[r]-sum[l])*(sum[r]-sum[l]);
}
int X(int x,int y)
{
return 2*(sum[x]-sum[y]);
}
int Y(int x,int y,int k)
{
return (dp[x][k]+sum[x]*sum[x])-( dp[y][k]+sum[y]*sum[y]);
}
int solve(int n,int k)
{
for(int i=1;i<=n;i++)
{
for(int j=1;j<=k;j++)dp[i][j]=inf;
}
for(int i=1;i<=n;i++)dp[i][1]=Sum(0,i);
int head=0,tail=0;
for(int i=2;i<=k;i++)
{
head=tail=0;
q[tail++]=i-1;
for(int j=i;j<=n;j++)
{
while(head+1<tail &&Y(q[head+1],q[head],i-1)<=sum[j]*X(q[head+1],q[head]) )head++;
dp[j][i]=dp[q[head]][i-1]+Sum(q[head],j);
while(head+1<tail &&Y(q[tail-1],q[tail-2],i-1)*X(j,q[tail-1])>=Y(j,q[tail-1],i-1)*X(q[tail-1],q[tail-2]))tail--;
q[tail++]=j;
}
}
return dp[n][k];
}
int main()
{
//freopen("D:\in.txt", "r", stdin);
//freopen("D:\ougly.txt","w",stdout);
int t;
sca(t);
while(t--)
{
int n,k;
scanf("%d%d",&n,&k);
for(int i=1; i<=n; i++)scanf("%d",&a[i]),a[n+i]=a[i];
int ans=inf;
for(int i=1; i<=n; i++)
{
int cnt=1;
sum[0]=0;
for(int j=i; j<=i+n-1; j++)
{
b[cnt]=a[j];
sum[cnt]=sum[cnt-1]+a[j];
cnt++;
}
ans=min(ans,solve(n,k));
}
printf("%d\n",ans);
}
}
/*
*/