题目来源
Description
给定 n n n 个数,将其划分为 k k k 段区间,需要求出每一段区间的价值之和的最大值。
一段区间的价值定义为这段区间内不同数字的个数。
- n ≤ 35000 , k ≤ 50 n\le35000,k\le50 n≤35000,k≤50。
Solution
朴素的DP方程是好想的。
令 f i , j f_{i,j} fi,j 表示将前 j j j 个数划分为 i i i 段价值之和的最大值。
转移方程即为:
f
i
,
j
=
max
p
=
i
−
1
j
−
1
{
f
i
−
1
,
p
+
c
n
t
(
p
+
1
,
j
)
}
f_{i,j}=\max_{p=i-1}^{j-1}\{f_{i-1,p}+cnt(p+1,j)\}
fi,j=p=i−1maxj−1{fi−1,p+cnt(p+1,j)}
其中
c
n
t
(
l
,
r
)
cnt(l,r)
cnt(l,r) 表示区间
[
l
,
r
]
[l,r]
[l,r] 中不同的数的个数。效率为
O
(
n
2
k
)
O(n^2k)
O(n2k)。
若想要进一步优化,因为DP方程已经是 O ( n k ) O(nk) O(nk) 的了,因此可以猜测效率应为 O ( n k ) O(nk) O(nk) 或 O ( n k log n ) O(nk\log n) O(nklogn),前者可能性较小。
可以发现,对于 c n t ( l , r ) cnt(l,r) cnt(l,r) 是可以优化计算过程的,又因为每次需要求出一段区间的最大值,就可以想到通过线段树进行优化。
在每次要计算第 i i i 个“阶段”时,先以 i − 1 i-1 i−1 “阶段”求出的答案建立 1 1 1 棵线段树(下标应从 0 0 0 开始)。
考虑每一个 a j a_j aj 其对哪些 c n t cnt cnt 值会产生贡献。
显然,若设 p r e j pre_j prej 表示满足前一个与 a j a_j aj 相等的数的位置,则会对 p r e j < l ≤ j pre_j<l\le j prej<l≤j 的区间产生贡献。
因此,我们需要扫描 a j a_j aj,每次将线段树中 [ p r e j , j ) [pre_j,j) [prej,j) 的区间加上 1 1 1(因为 c n t cnt cnt 那里的左端点加了 1 1 1 所以线段树的区间需要减去 1 1 1),同时更新 f i , j f_{i,j} fi,j 的值,即查找线段树中 [ i − 1 , j − 1 ] [i-1,j-1] [i−1,j−1] 的最大值(这样就保证 c n t cnt cnt 的右端点是 j j j)。
最后答案即为 f k , n f_{k,n} fk,n,效率 O ( n k log n ) O(nk\log n) O(nklogn)。
Code
#include <bits/stdc++.h>
using namespace std;
int n,k,a[35005],f[55][35005],pre[35005];
struct Node{
int l,r,Max,add;
}tree[140005];
void pushdown(int p){
if (tree[p].add){
tree[p<<1].add+=tree[p].add;
tree[p<<1|1].add+=tree[p].add;
tree[p<<1].Max+=tree[p].add;
tree[p<<1|1].Max+=tree[p].add;
tree[p].add=0;
}
}
void pushup(int p){
tree[p].Max=max(tree[p<<1].Max,tree[p<<1|1].Max);
}
void build(int p,int l,int r,int now){
tree[p].l=l,tree[p].r=r,tree[p].Max=tree[p].add=0;
if (l==r){
tree[p].Max=f[now][l];
tree[p].add=0;
return;
}
int Mid=(l+r)>>1;
build(p<<1,l,Mid,now);
build(p<<1|1,Mid+1,r,now);
pushup(p);
}
void update(int p,int l,int r){
if (l<=tree[p].l&&tree[p].r<=r){
tree[p].Max++;
tree[p].add++;
return;
}
pushdown(p);
int Mid=(tree[p].l+tree[p].r)>>1;
if (l<=Mid) update(p<<1,l,r);
if (r>Mid) update(p<<1|1,l,r);
pushup(p);
}
int query(int p,int l,int r){
if (l<=tree[p].l&&tree[p].r<=r) return tree[p].Max;
pushdown(p);
int Mid=(tree[p].l+tree[p].r)>>1,res=0;
if (l<=Mid) res=query(p<<1,l,r);
if (r>Mid) res=max(res,query(p<<1|1,l,r));
return res;
}
int main(){
scanf("%d%d",&n,&k);
for (int i=1;i<=n;i++) scanf("%d",&a[i]);
for (int i=1;i<=k;i++){
build(1,0,n,i-1);
for (int j=1;j<=n;j++) pre[j]=0;
for (int j=1;j<=n;j++){
update(1,pre[a[j]],j-1),f[i][j]=query(1,i-1,j-1);
pre[a[j]]=j;
}
}
printf("%d\n",f[k][n]);
return 0;
}