文章目录
原理
什么是二分栈
个人理解的二分栈,应该是一种类似刷表法的算法,对于每个点i,先更新自己的答案,再弹掉所有转移不如i的区间,最后在后面的
[
i
+
1
,
n
]
[i+1,n]
[i+1,n] 的区间中去二分查找可以更新的,以自己为最优转移的最靠左的点j,加入区间
[
j
,
n
]
[j,n]
[j,n] ,并入栈。
而与一般的单调队列的优化相比,就差不多是填表法和刷表法的区别
图解
先有一个区间
为了方便,假设前面的点都是由0号点转移的,即加入区间
[
1
,
n
]
[1,n]
[1,n]
然后枚举到第1个点,从0更新,然后发现后面的点都可以由1号点更新(比原来0号点更优),就加入区间
[
2
,
n
]
[2,n]
[2,n]
枚举到i=2的时候发现更新的区间不再是i=0时加入的区间了,应该是i=1时候加入的区间,就把第一个区间弹掉,更新2号点的答案,再加入二号点的对应区间
[
j
1
,
n
]
[j_1,n]
[j1,n] (二分查找)
依次向后枚举
这个时候,发现当i=5时,更新点j_3的方案中,5号点比4号点好,就弹出4号点加入的区间(因为满足决策单调,所以只需要判断左端点就知道整个区间以哪个点为决策点更优),加入5号二分出来的区间
[
j
4
,
n
]
[j_4,n]
[j4,n]
然后按以上规则依次转移即可
什么时候使用二分栈
首先,因为我们加入的区间一定是
[
j
,
n
]
[j,n]
[j,n],所以一定题目中要求的是决策点之间至少需要xx距离,而如果是至多,就不方便使用二分栈了,推荐用单调队列优化。
其次,我们的栈的更新是一个连续的区间,(即,只要左端点依据当前点最优,则区间内所有元素都如此),所以需要满足决策的单调性,(就是不能出现下面这种情况)
其中,
[
k
1
,
k
2
]
[k_1,k_2]
[k1,k2] 是以1号点作为最优转移点的区间
时间复杂度
因为每个点只可能出栈入栈一个区间,且每次查找区间的时候使用二分,故时间复杂度为 O ( n log n ) O(n\log{n}) O(nlogn)
二分栈优劣
优点
- 时间复杂度很优秀,将一般不带优化 O ( n 2 ) O(n^2) O(n2) 的时间优化到 O ( n log n ) O(n\log{n}) O(nlogn)
- 思维难度较低,不需要推很多式子,只需要证明决策单调即可
缺点
- 容易被一些强迫写 O ( n ) O(n) O(n) 算法的题卡掉
例题
[BZOJ1010][HNOI2008]玩具装箱
题目大意
给你N个玩具,要把每个玩具都打包,打包第i个到第j个玩具的长度是
x
=
j
−
i
+
∑
k
=
i
j
C
k
x = j - i + \sum_{k = i} ^ {j} C_k
x=j−i+∑k=ijCk 价格是
(
x
−
L
)
2
(x - L) ^ 2
(x−L)2
其实就是价格 c o s t ( i , j ) = ( j − i + ∑ k = i j C k − L ) 2 cost(i,j) = (j - i + \sum_{k = i} ^ {j} C_k - L)^2 cost(i,j)=(j−i+∑k=ijCk−L)2
其中,L是个常量
输入
第一行输入两个整数N,L.接下来N行输入Ci
输出
输出最小总价格
样例
5 4
3
4
2
1
4
限制
1 < = N < = 50000 , 1 < = L , C i < = 1 0 7 1 <= N <= 50000, 1 <= L, Ci <= 10 ^ 7 1<=N<=50000,1<=L,Ci<=107
分析
此题的转移式很好得:
d
p
[
i
]
=
min
{
d
p
[
j
]
+
c
o
s
t
(
j
+
1
,
i
)
,
0
<
=
j
<
i
,
1
<
=
i
<
=
n
}
dp[i] = \min\{dp[j] + cost(j+1,i),0 <= j < i,1 <= i <= n\}
dp[i]=min{dp[j]+cost(j+1,i),0<=j<i,1<=i<=n}
但是这样做就是
O
(
n
2
)
O(n^2)
O(n2) 的复杂度,考虑二分栈优化。
二分栈
因为显然
c
o
s
t
(
i
,
k
)
>
=
0
cost(i,k) >= 0
cost(i,k)>=0 还是个二次函数,如图:
红色线一段就是用
i
2
i_2
i2 这个点来转移更优的区间
所以可以用二分栈来做
代码
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <vector>
#include <queue>
#define re register
#define digit(x) (x >= '0' && x <= '9')
#define gc getchar
typedef long long LL;
using namespace std;
LL read()
{
LL x = 0, f = 1; char c = gc();
while (!digit(c)){if (c == '-') f = -f; c = gc();}
while (digit(c)) x = (x << 3) + (x << 1) + c - '0', c = gc();
return x * f;
}
const int N = 50005;
int n;
LL L;
LL a[N], sum[N];
LL d[N];
struct Node
{
int ind, l;
Node(){}
Node(int I, int L){ind = I, l = L;}
}q[N];
bool Pan(int j1, int j2, int i)
{
LL x = i - j1 - 1 + sum[i] - sum[j1];
LL y1 = d[j1] + (x - L) * (x - L);
x = i - j2 - 1 + sum[i] - sum[j2];
LL y2 = d[j2] + (x - L) * (x - L);
return y2 <= y1;
}
int main()
{
n = read(); L = read();
for (re int i = 1; i <= n; i++)
a[i] = read(),
sum[i] = a[i] + sum[i - 1];
int l = 1, r = 0;
q[++r] = Node(0, 0);
for (re int i = 1; i <= n; i++)
{
while (l < r && q[l + 1].l <= i) l++;
LL x = i - q[l].ind - 1 + sum[i] - sum[q[l].ind];
d[i] = d[q[l].ind] + (x - L) * (x - L);
while (l <= r && Pan(q[r].ind, i, q[r].l)) r--; //弹掉前面的决策不如i的区间
int l1 = i, r1 = n + 1;
if (l <= r) l1 = q[r].l;
while (l1 + 1 < r1)
{
int mid = (l1 + r1) >> 1;
if(Pan(q[r].ind, i, mid))
r1 = mid;
else l1 = mid;
}
if (r1 == n + 1) continue;
q[++r] = Node(i, r1);
}
printf("%lld\n", d[n]);
return 0;
}
CSP-S 2019 Day2T2 划分(88pts)
题目大意
将一段长度为n的序列划分成若干个区间,使得区间sum递增,且所有区间的平方和最小
其实大家差不多都知道吧
分析
其实就是先二分一下对于当前点i能选的下一个最左边的点的位置k,即:
其中,
s
u
m
(
i
,
k
)
>
=
s
u
m
(
l
a
i
,
i
)
sum(i,k) >= sum(la_i,i)
sum(i,k)>=sum(lai,i)
那么区间 [ k , n ] [k,n] [k,n] 就是当前以i为最优决策点的区间(因为选到k和它后面的点的时候从i转移肯定比从i之前的点转移更优)。
代码
#include <cstdio>
#include <algorithm>
#include <vector>
#include <queue>
#include <cstring>
using namespace std;
typedef long long LL;
typedef unsigned long long ull;
#define gc getchar
#define re register
#define digit(x) (x >= '0' && x <= '9')
#define ud unsigned
#define _i128 __int128
LL read()
{
LL x = 0, f = 1; char c = gc();
while(!digit(c)){if (c == '-') f = -f; c = gc();}
while(digit(c)) x = (x << 3) + (x << 1) + c -'0', c = gc();
return x * f;
}
const int N = 4e7 + 5;
int n;
int a[N], f[N];
struct Node
{
int pos, l;
Node(){}
Node(int P, int L){pos = P, l = L;}
}q[N];
LL s[N];
_i128 sqr(_i128 x)
{
return x * x;
}
void Print(_i128 x)
{
if(!x) return ;
Print(x / 10);
printf("%d", x % 10);
}
int main()
{
n = read();int op = read();
if(op)
{
static const LL mod = 1 << 30;
static LL b[N];
LL x, y, z, m;
x = read(), y = read(),
z = read(), b[1] = read(),
b[2] = read(), m = read();
for (re int i = 3; i <= n; i++)
b[i] = ( x * b[i - 1] + y * b[i - 2] + z ) % mod;
LL lp = 0, p, l, r;
for (re int i = 1; i <= m; i++)
{
p = read(), l = read(), r = read();
for (re int j = lp + 1; j <= p; j++)
a[j] = b[j] % (r - l + 1) + l;
lp = p;
}
}
else
for (re int i = 1; i <= n; i++)
a[i] = read();
for (re int i = 1; i <= n; i++)
s[i] = s[i - 1] + a[i];
int h = 1, t = 0;
q[++t] = Node(0, 1);
for (re int i = 1; i <= n; i++)
{
while (h < t && q[h + 1].l <= i) h++;
f[i] = q[h].pos;
LL p = s[i] - s[f[i]];
int l = i, r = n + 1;
while(l + 1 < r)
{
int mid = (l + r) >> 1;
if(s[mid] - s[i] >= p)
r = mid;
else l = mid;
}
if (r == n + 1) continue;
while (h <= t && r <= q[t].l) t--;
q[++t] = Node(i, r);
}
_i128 ans = 0; int now = n;
while(now)
{
ans += sqr(s[now] - s[f[now]]);
now = f[now];
}
if(!ans) putchar('0');
else Print(ans);
putchar('\n');
return 0;
}