给你长度为N的一个序列,让你将其分成连续的k段,每段的价值为其中数字种类的个数,求最大价值总和。
首先能想到n^2复杂度的dp
设定dp[i][j]表示到位子i,分成j段的最大价值总和。
dp[i][j]=max( dp[i][j],dp[k][j-1]+val(k+1,i) );k为这个数上一次出现的位置
可以用线段树加速转移。
考虑val(k+1,j).
我们遍历到第j个位子的时候,我们显然树上第k个位子表示的是dp[k][j-1]+val(k+1,i),那么考虑第i个数,它会对区间
(pre[a[i]] ,i-1)区间内的树上的位子有所影响。
那么我们遍历到第i个位子的时候,将树上区间(pre[a[i]],i)的值都+1。
这里pre[a[i]]表示的是a[i]这个数上一次出现的位子。
#include<bits/stdc++.h>
using namespace std;
const int maxn = 35555;
int a[maxn<<2],n,k;
int delt[maxn<<2],pre[maxn],pos[maxn];
int dp[maxn][52];
void pushup(int rt)
{
a[rt]=max(a[rt*2],a[rt*2+1]);
}
void pushdown(int rt)
{
delt[rt<<1]+=delt[rt];
delt[rt<<1|1]+=delt[rt];
a[rt<<1]+=delt[rt];
a[rt<<1|1]+=delt[rt];
delt[rt]=0;
}
void update(int rt,int x,int y,int l, int r, int val)
{
if (x<=l&&r<=y)
{
a[rt]+=val;
delt[rt]+=val;
return;
}
pushdown(rt);
int mid=(l+r)>>1;
if (x<=mid)
update(rt<<1,x,y,l,mid,val);
if(y>mid)
update(rt<<1|1,x,y,mid+1,r,val);
pushup(rt);
}
int query(int rt,int x,int y,int l,int r)
{
if (x<=l&&r<=y)
return a[rt];
pushdown(rt);
int mid=(l+r)>>1;
int ans=0;
if(x<=mid)
ans=query(rt<<1,x,y,l,mid);
if(y>mid)
ans=max(ans,query(rt<<1|1,x,y,mid+1,r));
return ans;
}
int main()
{
while(~scanf("%d%d",&n,&k))
{
memset(pos,0,sizeof(pos));
for(int i=1;i<=n;i++)
{
int x;
scanf("%d",&x);
pre[i]=pos[x];
pos[x]=i;
}
memset(dp,0,sizeof(dp));
for(int i=1;i<=n;i++)
{
dp[i][1]=dp[i-1][1];
if(!pre[i]) dp[i][1]++;
}
for(int j=2;j<=k;j++)
{
memset(a,0,sizeof(a));
memset(delt,0,sizeof(delt));
for(int i=1;i<=n;i++)
update(1,i,i,1,n,dp[i][j-1]);
for(int i=j;i<=n;i++)
{
update(1,pre[i],i-1,1,n,1);
dp[i][j]=query(1,j-1,i-1,1,n);
}
}
printf("%d\n",dp[n][k]);
}
return 0;
}