题目链接
给一个序列,要求把序列划分成k段,每一段里的权值是这一段里不同数字的个数。求如何划分使得k个区间的权值和最大。
显然这是一题用线段树维护的二维dp问题
给一个序列,要求把序列划分成k段,每一段里的权值是这一段里不同数字的个数。求如何划分使得k个区间的权值和最大。
我们容易想到dp[i][j]代表前j个分为i段时的最大值
dp[i][j] = dp[i-1][k] + size(k+1 , j) ( i-1<=k<j )
//size(a,b)表示a到b这个区间里有多少个不同数字
显然这是一题用线段树维护的二维dp问题
这里的每个叶节点,首先赋初值为dp[i-1][j]即到前j个数分成i-1段的价值,然后进行维护,即加上size(k+1,j),即每个结点代表的意思为dp[i-1][k] + size(k+1 , j) ( i-1<=k<j )。
接下来的问题就是怎么维护这个size(k+1 , j) ?
当枚举到num[j]的时候,把num[j]这个数字最近一次出现的位置记录一下p,然后把dp[i-1][p]到dp[i-1][j-1]加上1,然后dp[i][j] = max(dp[i-1][k]) ( i-1<=k < j) 。这个意思就是说p~j-1这个区间中没有出现过num[j]这个元素,所以要对他们的size(k+1,j)加1。
[ 这个地方可以算是求不同颜色的一种方法(get到新知识) ],于是我们在查询和加1的操作都可以通过线段树来维护.
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<cstring>
using namespace std;
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define maxn 40000
int sum[maxn<<2];
int lazy[maxn<<2];
int dp[60][maxn];
int pos[maxn];
int num[maxn];
void pushup(int rt)
{
sum[rt]=max(sum[rt<<1],sum[rt<<1|1]);
}
void pushdown(int rt)
{
if(lazy[rt]!=0)
{
sum[rt<<1]+=lazy[rt];
sum[rt<<1|1]+=lazy[rt];
lazy[rt<<1]+=lazy[rt];
lazy[rt<<1|1]+=lazy[rt];
}
lazy[rt]=0;
}
void build(int l,int r,int rt,int x)
{
lazy[rt]=0;
if(l==r)
{
sum[rt]=dp[x][l];
return ;
}
int m=(l+r)>>1;
build(lson,x);
build(rson,x);
pushup(rt);
}
void update(int L,int R,int l,int r,int rt)
{
if(L<=l&&r<=R)
{
sum[rt]++;
lazy[rt]++;
return ;
}
int m=(l+r)>>1;
pushdown(rt);
if(L<=m)update(L,R,lson);
if(R>m)update(L,R,rson);
pushup(rt);
}
int query(int L,int R,int l,int r,int rt)
{
if(L<=l&&r<=R)
{
return sum[rt];
}
int m=(l+r)>>1;
pushdown(rt);
int res=0;
if(L<=m)res=max(res,query(L,R,lson));
if(R>m)res=max(res,query(L,R,rson)) ;
return res;
}
int main()
{
int n,m;
memset(dp,0,sizeof dp);
memset(pos,0,sizeof pos);
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&num[i]);
for(int i=1;i<=n;i++)
if(!pos[num[i]])
{
dp[1][i]=dp[1][i-1]+1;
pos[num[i]]=i;
}
else
dp[1][i]=dp[1][i-1];
for(int i=2;i<=m;i++)//分成i段
{
build(1,n,1,i-1);
memset(pos,0,sizeof pos);
for(int j=i;j<=n;j++)//即前j个点分为i段
// dp[i][j] = dp[i-1][k] + size(k+1 , j) ( 0<=k<j )
{
int t=pos[num[j]];
if(!t)
t=1;
update(t,j-1,1,n,1);
int res=query(i-1,j-1,1,n,1);
//因为之前那个状态分了i-1段 所以至少要取i-1个 但最多取j-1 剩下的自成一段
dp[i][j]=res;
//cout<<res<<endl;
pos[num[j]]=j;
}
}
printf("%d\n",dp[m][n]);
}