题意:
有
n
n
n个数字
a
1
,
2...
n
a_{1,2...n}
a1,2...n,要分成k段,每段的价值为段内满足
[
a
i
=
=
a
j
]
a
n
d
[
i
=
=
j
]
[a_i==a_j] and[i==j]
[ai==aj]and[i==j]的
(
i
,
j
)
(i,j)
(i,j)对数。
(
n
,
a
i
≤
1
e
5
,
k
≤
m
i
n
(
n
,
20
)
)
(n,a_i\le 1e5,k\le min(n,20))
(n,ai≤1e5,k≤min(n,20))
解题思路:
有一个很显然的
O
(
n
2
k
)
O(n^2k)
O(n2k)的做法,
d
p
(
i
,
j
)
dp(i,j)
dp(i,j)表示把
a
1
,
2
,
.
.
j
a_{1,2,..j}
a1,2,..j分成
i
i
i段的最小价值,转移为
d
p
(
i
,
j
)
=
m
i
n
(
d
p
(
i
−
1
,
j
′
)
+
c
o
s
t
(
j
′
,
j
)
)
[
j
′
<
j
]
dp(i,j)=min(dp(i-1,j')+cost(j',j))\quad[j'<j]
dp(i,j)=min(dp(i−1,j′)+cost(j′,j))[j′<j]
这里
c
o
s
t
(
j
′
,
j
)
cost(j',j)
cost(j′,j)表示
(
j
′
,
j
]
(j',j]
(j′,j]这一段的价值。
考虑优化,往单调性方面去考虑:每个转移点的位置是否满足单调性?设
p
j
p_j
pj为
d
p
(
i
,
j
)
dp(i,j)
dp(i,j)的最左端的最优转移点,对于
(
j
′
<
j
)
(j'<j)
(j′<j)是否总有
p
j
′
≤
p
j
p_{j'}\le p_j
pj′≤pj?
答案是Yes。现在用反证法证明:
证明思路来自:https://codeforces.com/blog/entry/55046
假设存在
j
′
<
j
j'<j
j′<j,且
d
p
(
i
−
1
,
x
)
+
c
o
s
t
(
x
,
j
)
<
d
p
(
i
−
1
,
p
j
′
)
+
c
o
s
t
(
p
j
′
,
j
)
dp(i-1,x)+cost(x,j)<dp(i-1,p_{j'})+cost(p_{j'},j)
dp(i−1,x)+cost(x,j)<dp(i−1,pj′)+cost(pj′,j)满足
x
<
p
j
′
x<p_{j'}
x<pj′。
画图出来是这样:
那么显然的有:
c
o
s
t
(
x
,
j
)
−
c
o
s
t
(
p
j
′
,
j
)
>
c
o
s
t
(
x
,
j
′
)
−
c
o
s
t
(
p
j
′
,
j
′
)
cost(x,j)-cost(p_{j'},j)>cost(x,j')-cost(p_{j'},j')
cost(x,j)−cost(pj′,j)>cost(x,j′)−cost(pj′,j′)
由假设的条件我们得到
d
p
(
i
−
1
,
p
j
′
)
−
d
p
(
i
−
1
,
x
)
>
c
o
s
t
(
x
,
j
)
−
c
o
s
t
(
p
j
′
,
j
)
>
c
o
s
t
(
x
,
j
′
)
−
c
o
s
t
(
p
j
′
,
j
′
)
dp(i-1,p_{j'})-dp(i-1,x)> cost(x,j)-cost(p_{j'},j)>cost(x,j')-cost(p_{j'},j')
dp(i−1,pj′)−dp(i−1,x)>cost(x,j)−cost(pj′,j)>cost(x,j′)−cost(pj′,j′)
移项得到
d
p
(
i
−
1
,
p
j
′
)
+
c
o
s
t
(
p
j
′
,
j
′
)
>
d
p
(
i
−
1
,
x
)
+
c
o
s
t
(
x
,
j
′
)
dp(i-1,p_{j'})+cost(p_{j'},j')>dp(i-1,x)+cost(x,j')
dp(i−1,pj′)+cost(pj′,j′)>dp(i−1,x)+cost(x,j′)
这与
p
j
′
p_{j'}
pj′是
j
′
j'
j′的最左端最优转移点矛盾
证毕。
那么知道了它满足单调性之后,就可以采用类似整体二分的分治去转移了,因为每一层的搜索区间总长度是
O
(
n
)
O(n)
O(n)的,一共有
l
o
g
log
log层,所以总复杂度
O
(
k
n
l
o
g
n
)
O(knlogn)
O(knlogn)
代码:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 1e5 + 50;
int cnt[maxn], a[maxn];
ll ans;
void add(int x){
ans += cnt[x]; cnt[x]++;
}
void del(int x){
ans -= (cnt[x]-1); cnt[x]--;
}
int lp , rp;
void go(int l, int r){
while(lp > l) add(a[--lp]);
while(rp < r) add(a[++rp]);
while(lp < l) del(a[lp++]);
while(rp > r) del(a[rp--]);
}
ll dp[21][maxn];
int n, k, t;
void sol(int l, int r, int L, int R, ll *pre, ll *cur){
if(l > r) return;
int mid = (l+r)>>1, p;
for(int i = L; i <= min(R, mid-1); ++i){
go(i+1, mid);
if(cur[mid] > pre[i] + ans) cur[mid] = pre[i] + ans, p = i;
}
sol(l, mid-1, L, p, pre, cur); sol(mid+1, r, p, R, pre, cur);
}
int main()
{
scanf("%d%d", &n, &k);
for(int i = 1; i <= n; ++i) scanf("%d", &a[i]);
memset(dp, 0x3f, sizeof dp); dp[0][0] = 0;
lp = rp = 1; add(a[1]);
for(t = 1; t <= k; ++t){
dp[t][0] = 0;
sol(1, n, 0, n-1, dp[t-1], dp[t]);
}
cout<<dp[k][n]<<endl;
}
/*
20 3
1 3 2 0 1 0 2 2 2 0 1 1 1 3 1 3 3 2 3 0
*/