wqs 二分(Alien Trick)优化 DP 学习笔记

本文介绍了WQS二分算法,一种用于优化动态规划问题的方法,以解决具有凸包性质的成本函数。通过实例分析了如何将P4983忘情问题中的序列分割转化为数学模型,并展示了如何使用WQS二分技巧将其复杂度降低到O(npolylog)。
摘要由CSDN通过智能技术生成

引入

wqs 二分是由王钦石提出的一类用来优化 dp 的算法,在国外被称为 「Alien Trick」(源于 IOI 2016 的题目 aliens)。

它可以用来解决如下问题:从 n n n 个物品中选 x x x 个,设其代价为 f ( x ) f(x) f(x),若 f ( x ) f(x) f(x) 构成一个凸包,那么就可以用 wqs 二分来解决。

例如这个题:P4983 忘情

题目大意:

一段序列 a a a,其权值为

( ∑ i = 1 n a i a i ˉ + a i ˉ ) 2 a i ˉ 2 \frac{(\sum\limits_{i=1}^n a_i\bar{a_i}+\bar{a_i})^2}{\bar{a_i}^2} aiˉ2(i=1naiaiˉ+aiˉ)2

现在给出长度为 n n n 的序列,将这个序列分成长度为 m m m 的序列,求这 m m m 个序列权值和的最小值。

首先,化简式子,得原式为

( ∑ i = 1 n ( a i + 1 ) a i ˉ a i ˉ ) 2 = ( ∑ i = 1 n a i + 1 ) 2 = ( S i + 1 ) 2 \begin{aligned} (\frac{\sum\limits_{i=1}^n (a_i+1)\bar{a_i}}{\bar{a_i}})^2&=(\sum\limits_{i=1}^n a_i+1)^2\\ &=(S_i+1)^2 \end{aligned} (aiˉi=1n(ai+1)aiˉ)2=(i=1nai+1)2=(Si+1)2

其中 S i = ∑ i = 1 n a i S_i=\sum_{i=1}^na_i Si=i=1nai

f i , j f_{i,j} fi,j 表示分成了 i i i 个序列,考虑到第 j j j 个数的权值和最小值。于是有状态转移方程

f i , j = min ⁡ k = 0 j − 1 ( f i − 1 , k + ( S j − S k ) 2 ) f_{i,j}=\min\limits_{k=0}^{j-1} (f_{i-1,k}+ (S_j-S_k)^2) fi,j=k=0minj1(fi1,k+(SjSk)2)

可以一眼看出这可以斜率优化或者李超树维护,但这样做依旧是 O ( n 2 polylog ) O(n^2 \text{polylog}) O(n2polylog) 的,怎么办呢?这时就要用到 wqs 二分。

wqs 二分优化 dp

首先记 g i = f i , n g_i=f_{i,n} gi=fi,n 表示分成 i i i 个序列的代价最小值。打表可知 g i g_i gi 是下凸的,那么我们可以二分斜率,每次判断切点是否是 i = m i=m i=m,如果不是就继续二分。

怎么做呢?首先二分出来斜率 k k k,那么在每个点 ( i , g i ) (i,g_i) (i,gi) 上有

g i = k × i + b i g_i=k\times i+b_i gi=k×i+bi

显然截距 b i b_i bi 最小的那个 i i i 就是切点。原式变形得

b i = g i − k × i b_i=g_i-k\times i bi=gik×i

我们发现在 f i , j ← f i − 1 , l f_{i,j}\leftarrow f_{i-1,l} fi,jfi1,l 时,可以将其改成 f i − 1 , j ′ ← f i , l ′ − k f_{i-1,j}'\leftarrow f_{i,l}'-k fi1,jfi,lk,即每转移一次 i i i 就减一次 k k k,不难发现 f i , l ′ = f i , l − k × i f_{i,l}'=f_{i,l}-k\times i fi,l=fi,lk×i,最终 f i , n ′ f_{i,n}' fi,n 最小的那个 i i i 就是切点。但这样还不够,我们发现最终只需要知道最小的终状态转移了几次,并没有规定一定要转移几次,所以可以踢掉 i i i 这一维,相当于做「将序列分成若干个序列,使其代价和最小的方案是分成多少个」这个问题。

于是就可以直接计算 g g g,方程为

g i = min ⁡ j = 0 i − 1 ( g j + ( S i − S j ) 2 ) − k g_{i}=\min\limits_{j=0}^{i-1} (g_{j}+ (S_i-S_j)^2)-k gi=j=0mini1(gj+(SiSj)2)k

其中 k k k 为二分的斜率。解决一次上面的问题是 O ( n ) O(n) O(n) O ( n log ⁡ n ) O(n\log n) O(nlogn) 的,每次转移另外记录转移次数 c n t i = c n t j + 1 cnt_i=cnt_j+1 cnti=cntj+1 j j j g i g_i gi 的最优决策点。然后每次二分判断 c n t n cnt_n cntn m m m 的关系,直到二分结束,就算出了最小截距,但不是最终答案。最终答案还要将其加上 k × m k\times m k×m 得到 g m g_m gm

另外注意对于一个斜率可能存在截距相同的点,dp 时尽可能选择靠右的点,二分时令横坐标不小于 m m m 的点为合法点,于是右半部分是合法区间,二分找到最靠左的合法点就是答案。

于是,我们就得到了一个 O ( n polylog ) O(n\text{polylog}) O(npolylog) 的算法。

这个题目的代码:

/*
    Program: P4983.cpp
    Author: 1l6suj7
    DateTime: 2024-01-22 14:52:56
    Description:
*/

#include <bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define lp(i, j, n) for(int i = j; i <= n; ++i)
#define dlp(i, n, j) for(int i = n; i >= j; --i)
#define mst(n, v) memset(n, v, sizeof(n))
#define mcy(n, v) memcpy(n, v, sizeof(v))
#define INF 1e18
#define MAX4 0x3f3f3f3f
#define MAX8 0x3f3f3f3f3f3f3f3f
#define pii pair<int, int>
#define pll pair<ll, ll>
#define co(x) cerr << (x) << ' '
#define cod(x) cerr << (x) << endl
#define fi first
#define se second
#define eps 1e-8
#define lc(x) ((x) << 1)
#define rc(x) ((x) << 1 ^ 1)
#define pb(x) emplace_back(x)

using namespace std;

const int N = 100010;

int n, m, a[N], cnt[N];
ll q[N], f[N], s[N]; int l, r;

ll gety(int i, int k) { return f[k] + (s[i] - s[k] + 1) * (s[i] - s[k] + 1); }

long double getk(int x1, int x2) { return (long double)(f[x1] + s[x1] * (s[x1] - 2) - f[x2] - s[x2] * (s[x2] - 2)) / (s[x1] - s[x2]); }

int judge(ll mid) {		// 斜率优化
    mst(f, 0), mst(cnt, 0), l = 1, r = 0;
    q[++r] = 0;
    lp(i, 1, n) {
        while(l < r && gety(i, q[l + 1]) <= gety(i, q[l])) ++l;		// 前面两个点截距相等也出队,保证选最右边的点
        f[i] = gety(i, q[l]) - mid, cnt[i] = cnt[q[l]] + 1;
        while(l < r && getk(i, q[r]) < getk(q[r], q[r - 1])) --r;
        q[++r] = i;
    }
    // cod(cnt[n]);
    return cnt[n];
}

signed main() {
    // freopen("P4983.in", "r", stdin);
    // freopen("P4983.out", "w", stdout);
#ifndef READ
    ios::sync_with_stdio(false);
    cin.tie(0);
#endif
    cin >> n >> m;
    lp(i, 1, n) cin >> a[i], s[i] = s[i - 1] + a[i];
    ll l = -1e15, r = 1e15;
    while(l < r) {		// 二分斜率
        ll mid = l + r >> 1;
        if(judge(mid) >= m) r = mid;
        else l = mid + 1;
    }
    judge(l);
    cout << f[n] + l * m << endl;
    return 0;
}

其他题目

[COCI2018-2019#4] Akvizna

[IOI2016] aliens

[CF739E] Gosha is hunting

[CSES] Programmers and Artists

等我写完了再来写题解

  • 24
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值