今天听fy大佬讲了一道极好的DP题,不同于其他很多DP,极简的转移方式与数据结构优化,会使这些题更接近于一个数据结构的暴力题目。但这道题的绝妙之处在于不需要用任何数据结构,优化巧妙,将重点更加放在数组本身的性质而非如何用数据结构暴力维护降复杂度。
原题:传送门
题目简介:一个长度为n的数列,将其划分为任意个区间,使代价和最小。这里定义代价指区间内数的个数的平方。输出代价最小值。
经过简单的分析,我们得出状态转移方程:dp[i]=min(dp[i],dp[j]+w[i][j]^2),dp[i]指位置i前的数的代价最小值。w[i][j]指区间[i,j]内数字种类数。于是我们获得了一个复杂度O(n^2)的算法,能过50%左右的数据。
暴力DP代码如下:
#include <bits/stdc++.h>
using namespace std;
const int SIZE=40005;
const int INF=0x3f3f3f3f;
int n,m,cnt=0;
int a[SIZE],col[SIZE],dp[SIZE];
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++){
scanf("%d",&a[i]);
}
memset(dp,INF,sizeof(dp));
memset(col,0,sizeof(col));
dp[0]=0;
for (int i=1;i<=n;i++){
cnt=0;
memset(col,0,sizeof(col));
for (int j=i;j>=1;j--){
if (!col[a[j]]){
col[a[j]]=1;
cnt++;
}
dp[i]=min(dp[i],dp[j-1]+cnt*cnt);
}
}
printf("%d",dp[n]);
return 0;
}
实测:在洛谷上能得到50分的好成绩。
考场上应该可以了,但我们更应追求正解。
其实离正解已经不远了,我们分析一下:因为可以划无数段,可以人为将每一个数单独并称一组,可以得到代价和为1^2*n=n,即对于每组数,答案一定小于n。很显然的是,如果一个区间内的数种类超过了sqrt(n),则代价将超过n,但上面得出答案一定小于等于n,所以一个区间中数种类一定小于等于sqrt(n)。于是可以弄一个数组pos[i]表示第i位以前种类数为i的最远位置,于是得到一个全新的dp方程:dp[i]=min(dp[i],dp[pos[j]-1]+j*j),因为我们规定j<sqrt(n),所以只要枚举sqrt(n)个j,复杂度为O(nsqrt(n))。
至于如何快速更新pos[]数组,可以用“四保一”技巧,即使用四个数组相互维护来维护目标数组,还是很有用的。具体地,pre[i]表示第i位数字上一次出现的位置,next[i]表示第i位数字下一次出现的位置,这两个数组之间第i位数字只出现一次,即:区间( pre[i],next[i] )内数字a[i]只出现1次。last[i]表示数字i最后一次出现位置,可以判断下界。cnt[i]表示当前pos[i]出现的次数。pre[i]<pos[i],表示在[pos[j],i-1]中第i位没有出现过,则cnt[i]++,cnt[j]>j时,pos[j]应该右移以保证其出现过,next同理,last来判断边界。
#include <bits/stdc++.h>
using namespace std;
const int SIZE=40005;
const int INF=0x3f3f3f3f;
int n,m;
int a[SIZE],dp[SIZE],pre[SIZE],next[SIZE],last[SIZE],cnt[SIZE],pos[SIZE];
void clear()
{
for (int i=1;i<=n;i++){
pre[i]=last[a[i]];
next[last[a[i]]]=i;
last[a[i]]=i;
next[i]=n+1;
}
for (int i=1;i<=sqrt(n);i++)
pos[i]=1;
memset(dp,INF,sizeof(dp));
dp[0]=0;
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++)
scanf("%d",&a[i]);
clear();
for (int i=1;i<=n;i++)
for (int j=1;j<=(int)sqrt(n);j++){
if (pre[i]<pos[j]) cnt[j]++;
if (cnt[j]>j){
while (next[pos[j]]<i) pos[j]++;
cnt[j]--; pos[j]++;
}
dp[i]=min(dp[i],dp[pos[j]-1]+j*j);
}
printf("%d",dp[n]);
return 0;
}