Description
Input
第一行给出两个整数\(n\)和\(k\)(\(1 \leqslant n \leqslant 35000, 1 \leqslant k \leqslant min(n,50)\)),分别表示蛋糕数和盒子数,第二行给出\(n\)个数\(a_i\),分别表示每个蛋糕的种类(\(1 \leqslant a_i \leqslant n\))。
Output
输出一个整数,表示能够得到的最大值。
Sample Input
7 2
1 3 3 1 4 4 4
Sample Output
5
Solution
DP。
- 状态:\(dp[j][i]\)表示前\(i\)个蛋糕放到\(j\)个盒子中能得到的最大价值
- 初始:\(dp[1][i] = c[1,i]\),\(c[1,i]\)表示区间\([1,i]\)中不同蛋糕的种类数
- 转移:\(dp[j][i] = max_{1 \leqslant k \leqslant i}\{dp[j-1][k-1]+c[k,i]\}\),第\(j\)个盒子中放第\(k\)到第\(i\)个蛋糕
- 目标:\(dp[m][n]\)
dp数组的大小是\(m \times n\),直接按公式求每个dp值需要\(O(n)\)的时间,这样总的时间复杂度是\(O(kn^2)\),会超时。
考虑用线段树加速求\(dp[j][i]\)时的查询。dp公式中第\(j\)行的值取决于第\(j-1\)行的值,故在计算出第\(j-1\)行的dp值后,建立关于第\(j-1\)行的线段树,从而快速求出第\(j\)行的值。
但是求第\(j\)行的过程并不是对上一行的简单查询,\(c[k,i]\)的值在\(i\)取不同值的时候在变化。
考虑在求\(dp[j][i]\)时先将线段树从\(i-1\)更新到\(i\),再进行查询。
令线段树的最小单元\(a[k]=dp[j-1][k-1]+c[k,i]\),于是转移方程变为\(dp[j][i] = max_{1 \leqslant k \leqslant i}\{a[k]\}\)。注意到对于每个\(a[k]\),\(dp[j-1][k-1]\)是定值,我们只需要考虑由\(i-1\)到\(i\)时\(c[k,i]\)的变化。设第\(i\)个蛋糕的种类是\(y\),前一个\(y\)出现的位置是\(p\),那么\(c[p+1,i+1]\)到\(c[i,i]\)需要加\(1\),前面的不需要更新,这样就通过一次更新从\(i-1\)转移到了\(i\)。
对于每个\(dp[j][i]\),需要次更新和一次查询,用时\(O(logn)\),总时间复杂度为\(O(knlogn)\)。
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int INF = 0x3f3f3f3f;
const int N = 35000 + 10;
struct Node
{
int l, r;
int sum, max, maxi;
int lazy;
} tree[4 * N];
int fa[N], a[N];
inline int L(int i) { return i << 1; }
inline int R(int i) { return (i << 1) + 1; }
inline int P(int i) { return i >> 1; }
void pushup(int i)
{
tree[i].sum = tree[L(i)].sum + tree[R(i)].sum;
tree[i].max = max(tree[L(i)].max, tree[R(i)].max);
tree[i].maxi =
tree[L(i)].max >= tree[R(i)].max ? tree[L(i)].maxi : tree[R(i)].maxi;
}
void build(int i, int left, int right)
{
tree[i].l = left; tree[i].r = right;
tree[i].lazy = 0;
if (left == right)
{
tree[i].sum = a[left];
tree[i].max = a[left];
tree[i].maxi = left;
fa[left] = i;
return;
}
int mid = left + (right - left >> 1);
build(L(i), left, mid);
build(R(i), mid + 1, right);
pushup(i);
}
void pushdown(int i)
{
if (!tree[i].lazy) return;
tree[L(i)].sum += (tree[L(i)].r - tree[L(i)].l + 1) * tree[i].lazy;
tree[L(i)].max += tree[i].lazy;
tree[L(i)].lazy += tree[i].lazy;
tree[R(i)].sum += (tree[R(i)].r - tree[R(i)].l + 1) * tree[i].lazy;
tree[R(i)].max += tree[i].lazy;
tree[R(i)].lazy += tree[i].lazy;
tree[i].lazy = 0;
}
void update(int i, int left, int right, int val)
{
if (left <= tree[i].l && right >= tree[i].r)
{
tree[i].sum += (tree[i].r - tree[i].l + 1) * val;
tree[i].max += val;
tree[i].lazy += val;
if (tree[i].l == tree[i].r) a[tree[i].l] += val;
return;
}
pushdown(i);
int mid = tree[i].l + (tree[i].r - tree[i].l >> 1);
if (left <= mid) update(L(i), left, right, val);
if (right > mid) update(R(i), left, right, val);
pushup(i);
}
int query_max(int i, int left, int right)
{
if (left <= tree[i].l && right >= tree[i].r) return tree[i].max;
pushdown(i);
int maxx = -INF;
int mid = tree[i].l + (tree[i].r - tree[i].l >> 1);
if (left <= mid) maxx = max(maxx, query_max(L(i), left, right));
if (right > mid) maxx = max(maxx, query_max(R(i), left, right));
return maxx;
}
int t[N], dp[55][N];
int pre[N], left[N];
int main()
{
int n, m;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", t + i);
memset(dp, 0, sizeof(dp));
memset(pre, 0, sizeof(pre));
for (int i = 1; i <= n; i++) left[i] = pre[t[i]] + 1, pre[t[i]] = i;
dp[1][0] = 0;
for (int i = 1; i <= n; i++)
dp[1][i] = dp[1][i - 1] + (left[i] == 1);
for (int j = 2; j <= m; j++)
{
for (int i = 1; i <= n; i++) a[i] = dp[j - 1][i - 1];
build(1, 1, n);
for (int i = 1; i <= n; i++)
{
update(1, left[i], i, 1);
dp[j][i] = query_max(1, 1, i);
}
}
printf("%d\n", dp[m][n]);
return 0;
}