题意大概是给一个序列,每个数表示一种颜色,每次可以取走连续的一段,花费是取走的这段中颜色数量的平方,问取完这个序列的最小花费是多少。
首先,容易想到一个N^2的DP,dp[i]=min(dp[j-1]+v(j,i)^2),其中j=1..i,v(j,i)表示从j到i的序列中不同的颜色数,N是50000所以肯定要T...因为转移实际上是按不同的颜色数来转移的,所以可以用一个数组minp[k]来表示区间右端点为i时,出现k种颜色的序列的中,左端点的dp值最小的位置,也就是在确定上式中v(j,i)=k的时候,找最小的dp[j-1]。那么每次dp的时候,先维护一下minp[],然后枚举颜色数转移就行了,注意到最坏的情况也可以n个位置各自合并达到代价n,所以颜色数最多枚举到sqrt(n)就行了...维护minp[]的话,每次dp到i时,找一下a[i]左边最近的位置lp,如果minp[j]在lp,i之间,那么用minp[j]去尝试更新minp[j+1]的值,并且无论是否成功更新minp[j+1]都把minp[j]清零0,从1到sqrt(n)枚举完j后,手动处理一下边界minp[1],就是与i颜色相同并且位置连续的最左面的位置con[i],这个也可以预处理下来。
/*=============================================================================
# Author:Erich
# FileName:
=============================================================================*/
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include <map>
#include <queue>
#include <stack>
#define lson id<<1,l,m
#define rson id<<1|1,m+1,r
using namespace std;
typedef long long ll;
const int inf=0x3f3f3f3f;
const ll INF=1ll<<60;
const double PI=acos(-1.0);
int n,m;
const int maxn=100050;
ll dp[maxn];
int minp[maxn];
int a[maxn],b[maxn];
int last[maxn];
int c[maxn];
int con[maxn];
int main()
{
// freopen("in.txt","r",stdin);
while(~scanf("%d",&n))
{
for (int i=1; i<=n; i++) scanf("%d",&a[i]);
for (int i=0; i<n; i++)
b[i]=a[i+1];
sort(b,b+n);
m=unique(b,b+n)-b;
map<int,int>mp;
int sp=0;
for (int i=0; i<n; i++)
mp[b[i]]=sp++;
for (int i=1; i<=n; i++)
a[i]=mp[a[i]];
memset(dp,0,sizeof dp);
memset(minp,0,sizeof minp);
memset(last,-1,sizeof last);
memset(c,0,sizeof c);
for (int i=1; i<=n; i++)
{
if (c[a[i]]) last[i]=c[a[i]];
c[a[i]]=i;
}
int p1=1,p2=1;
con[1]=1;
while(p1<=n)
{
while(a[p2+1]==a[p1]) p2++,con[p2]=p1;
p1=p2+1;
p2=p1;
con[p2]=p1;
}
m=min(m,(int)sqrt((double)n));
// m=min(m,5);
for (int i=1; i<=n; i++)
{
int lp=last[i];
for (int j=m; j>=1; j--)
{
if (minp[j])
{
if (minp[j]>lp)
{
if (minp[j+1]==0 || dp[minp[j+1]]>dp[minp[j]]) minp[j+1]=minp[j];
minp[j]=0;
}
}
}
minp[1]=con[i];
dp[i]=dp[i-1]+1;
for (int j=m; j>=1; j--)
{
if (minp[j])
{
int p=minp[j];
dp[i]=min(dp[i],dp[p-1]+(ll)j*j);
}
}
}
printf("%I64d\n",dp[n]);
}
return 0;
}