题目
有一个长为n(n<=3e5)的数列a[],
你可以把这n个数分成k(k<=50)段,
每一段的价值是这一段内不同数字的个数,
求最大价值
思路来源
https://blog.csdn.net/mengxiang000000/article/details/76576435
题解
首先,想到一个区间dp的做法
dp[i][j]表示把分成i段的前j个值的最大价值,
则有dp[i][j]=max(dp[i-1][pos],v(pos+1,j)),pos+1<=j
即分成i-1段,枚举上次的分段的结尾点pos,
v(pos+1,j)为后面这一段不同数字的个数
而这样做的复杂度是O(k*n*n)的
然后考虑线段树优化,当在最后一段加入a[j]这个数的时候,
若v(pos+1,j)内已经包含a[j],a[j]就没有贡献了,
所以a[j]只会对没有出现过a[j]的区间造成1的贡献,
记last[a[j]]为a[j]上一次出现的位置,
则只要令last[a[j]]<=pos<=j-1,即[pos+1,j]内只含一次a[j]这个值,
那么每次先把a[j]的1的贡献加到[pos+1,j]里,再统计[0,j-1]里最大的dp值,
相当于dp[j]=max(dp[0]到dp[last[a[j]]-1]里的最大值,1+(dp[last[a[j]]到dp[j-1]里的最大值))
这里的j是指第j个数,前面可以加上一维i代表分成i段
而这个i是可以用滚动数组滚掉的,每次重建线段树赋上次更新的dp数组值
线段树里存的是上一次分成i-1段的最优dp值,本次更新分成i段的
最终dp[n]就是n个数分成k段的
心得
注意由于第一个数没有左端界,强赋左端界last[a[j]]=0将其写进线段树里的写法
此外令区间加1之后,区间的最大值显然也只增加了1
区间维护最大值,更新的时候只对区间最大值和标记更新即可
代码
#include<iostream>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;
const int maxn=35010;
int n,k;
int a[maxn];
int dp[maxn],last[maxn];
int dat[maxn*4],cov[maxn*4];
void pushup(int p)
{
dat[p]=max(dat[p<<1],dat[p<<1|1]);
}
void pushdown(int p,int l,int r)
{
if(cov[p])
{
dat[p<<1]+=cov[p];
dat[p<<1|1]+=cov[p];
cov[p<<1]+=cov[p];
cov[p<<1|1]+=cov[p];
cov[p]=0;
}
}
void build(int p,int l,int r)
{
cov[p]=0;
if(l==r)
{
dat[p]=dp[l];
return;
}
int mid=(l+r)>>1;
build(p<<1,l,mid);
build(p<<1|1,mid+1,r);
pushup(p);
}
void update(int p,int l,int r,int ql,int qr,int v)
{
if(ql<=l&&r<=qr)
{
dat[p]+=v;
cov[p]+=v;
return;
}
pushdown(p,l,r);
int mid=(l+r)>>1;
if(ql<=mid)update(p<<1,l,mid,ql,qr,v);
if(qr>=mid+1)update(p<<1|1,mid+1,r,ql,qr,v);
pushup(p);
}
int query(int p,int l,int r,int ql,int qr)
{
if(ql<=l&&r<=qr)return dat[p];
int res=0,mid=(l+r)>>1;
pushdown(p,l,r);
if(ql<=mid)res=max(res,query(p<<1,l,mid,ql,qr));
if(qr>=mid+1)res=max(res,query(p<<1|1,mid+1,r,ql,qr));
return res;
}
int main()
{
scanf("%d%d",&n,&k);
for(int i=1;i<=n;++i)
scanf("%d",&a[i]);
for(int i=1;i<=k;++i)
{
build(1,0,n);
for(int j=1;j<=n;++j)
dp[j]=last[a[j]]=0;
for(int j=1;j<=n;++j)
{
update(1,0,n,last[a[j]],j-1,1);
last[a[j]]=j;
dp[j]=query(1,0,n,0,j-1);
}
}
printf("%d\n",dp[n]);
return 0;
}