题意
给定一个序列 {a1,a2,⋯,an},要把它分成恰好 k 个连续子序列。
每个连续子序列的费用是其中相同元素的对数,求所有划分中的费用之和的最小值。
2≤n≤105,2≤k≤min(n,20),1≤ai≤n
题解
显然具有决策单调性,可以用四边形不等式优化DP
唯一的问题是怎么快速计算w(l,r)的贡献
用分治的写法。暴力从fa区间的左右端点移动到当前左右端点计算贡献即可。
考虑分治的过程,总共logn层,每层决策区间长度和是n,所以只移动左端点复杂度nlogn
每层的右端点移动也是O(n)的,
所以计算贡献总复杂度是nlogn
总结:
这道题利用了分治的性质证明复杂度,要仔细的推一推!
每层区间做完之后要还原成上一个区间的样子
#include<bits/stdc++.h>
using namespace std;
#define rep(i,l,r) for(register int i = l ; i <= r ; i++)
#define repd(i,r,l) for(register int i = r ; i >= l ; i--)
#define rvc(i,S) for(register int i = 0 ; i < (int)S.size() ; i++)
#define rvcd(i,S) for(register int i = ((int)S.size()) - 1 ; i >= 0 ; i--)
#define fore(i,x)for (register int i = head[x] ; i ; i = e[i].next)
#define forup(i,l,r) for (register int i = l ; i <= r ; i += lowbit(i))
#define fordown(i,id) for (register int i = id ; i ; i -= lowbit(i))
#define pb push_back
#define prev prev_
#define stack stack_
#define mp make_pair
#define fi first
#define se second
#define lowbit(x) (x&(-x))
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
typedef pair<ll,ll> pr;
const ll inf = 2e18;
const int N = 3e6 + 10;
const int maxn = 100020;
const ll mod = 1e9 + 7;
int n,k;
int num[maxn],a[maxn];
ll f[22][maxn],cur;
inline void add(int x){
cur += num[x];
num[x]++;
}
inline void dec(int x){
num[x]--;
cur -= num[x];
}
void solve(int k,int l,int r,int pl,int pr,int lastl,int lastr,ll lastw){
if ( l > r ) return;
int mid = (l + r) >> 1,id = pl;
cur = lastw;
if ( mid > lastr ){
rep(i,lastr + 1,mid) add(a[i]);
}
else{
repd(i,lastr,mid + 1) dec(a[i]);
}
if ( id > lastl ){
rep(i,lastl,id - 1) dec(a[i]);
}
else{
repd(i,lastl - 1,id) add(a[i]);
}
f[k][mid] = f[k - 1][id - 1] + cur;
rep(i,pl + 1,pr){
dec(a[i - 1]);
if ( f[k - 1][i - 1] + cur < f[k][mid] ) f[k][mid] = f[k - 1][i - 1] + cur , id = i;
}
solve(k,l,mid - 1,pl,id,pr,mid,cur);
solve(k,mid + 1,r,id,pr,pr,mid,cur);
//恢复成上一层
if ( mid > lastr ){
repd(i,mid,lastr + 1) dec(a[i]);
}
else{
rep(i,mid + 1,lastr) add(a[i]);
}
if ( pr > lastl ){
repd(i,pr - 1,lastl) add(a[i]);
}
else{
rep(i,pr,lastl - 1) dec(a[i]);
}
cur = lastw;
}
int main(){
scanf("%d %d",&n,&k);
rep(i,1,n) scanf("%d",&a[i]);
rep(i,1,n){
add(a[i]);
f[1][i] = cur;
}
rep(i,2,k){
rep(i,1,n) num[i] = 0;
cur = 0;
solve(i,1,n,1,n,1,0,0);
}
cout<<f[k][n]<<endl;
}