斜率优化 DP,又叫凸壳优化 (Convex Hull Trick) 。它是用来求解
d
p
[
i
]
=
max
j
∈
S
(
i
)
f
(
i
)
+
k
(
j
)
x
(
i
)
+
b
(
j
)
d p[i]=\max _{j \in S(i)} f(i)+k(j) x(i)+b(j)
dp[i]=maxj∈S(i)f(i)+k(j)x(i)+b(j) 并且满足一定条件的问题。
比如,满足右边和
d
p
[
j
]
d p[j]
dp[j] 无关,并且
S
(
i
)
S(i)
S(i) 为全集。或者
S
(
i
)
=
{
j
∣
j
<
i
}
S(i)=\{j \mid j<i\}
S(i)={j∣j<i} or
S
(
i
)
=
{
j
∣
j
>
i
}
S(i)=\{j \mid j>i\}
S(i)={j∣j>i} 等等。 其中max也可以改为
min
\min
min 。
先考虑一个简化的问题:
d
p
[
i
]
=
max
j
∈
S
(
i
)
f
(
i
)
+
k
(
j
)
x
(
i
)
+
b
(
j
)
d p[i]=\max _{j \in S(i)} f(i)+k(j) x(i)+b(j)
dp[i]=maxj∈S(i)f(i)+k(j)x(i)+b(j)
d
p
\mathrm{dp}
dp 式子右边和
d
p
[
j
]
\mathrm{dp}[\mathrm{j}]
dp[j] 无关,且
S
(
i
)
S(i)
S(i) 为全集。
对于固定的
i
i
i 来说
f
(
i
)
,
x
(
i
)
f(i), x(i)
f(i),x(i) 都是常数,将
f
(
i
)
f(i)
f(i) 忽略,
x
(
i
)
x(i)
x(i) 记为
x
,
d
p
[
i
]
x , d p[i]
x,dp[i] 记为
y
y
y ,我们可以将转移方程写成
y
=
max
j
k
(
j
)
x
+
b
(
j
)
y=\max _{j} k(j) x+b(j)
y=maxjk(j)x+b(j)
对于每个
j
j
j 来说,将
x
x
x 看成变量,
l
j
(
x
)
=
k
(
j
)
x
+
b
(
j
)
l_{j}(x)=k(j) x+b(j)
lj(x)=k(j)x+b(j) 是一条直线。而上式就是在这些直线中
x
=
x
(
i
)
x=x(i)
x=x(i) 时, 对应
y
y
y 值最大的。即
y
(
i
)
=
max
j
l
j
(
x
(
i
)
)
y(i)=\max _{j} l_{j}(x(i))
y(i)=maxjlj(x(i))
如果该式在
j
=
p
j=p
j=p 时取最大,说明
l
p
(
x
(
i
)
)
≥
l
q
(
x
(
i
)
)
,
∀
q
l_{p}(x(i)) \geq l_{q}(x(i)), \forall q
lp(x(i))≥lq(x(i)),∀q
用图来表示,对于任意的
i
i
i要求
y
(
i
)
y(i)
y(i)的最大值那么一定是落在蓝色的直线上面的,而维护的蓝色的这些直线即为凸包。
这里提一下如何在加边的过程中维护凸包?
给出一种确定方法,先将直线按照斜率从小到大排序或者从大到小排序(越来越斜,注意正负,对于
k
k
k为负数的直线,
k
k
k越小越斜,对于
k
k
k为正数的直线,
k
k
k越大)再依次加入
对于图中的直线先加入绿
h
h
h再加入黑色
f
f
f再加入紫色
g
g
g,每次加入新的直线时算出它与已经加入的直线中的倒数第一条的交点的横坐标,比如这里黑色与紫色的交点为
G
G
G,然后再拿这个坐标与已经加入的最后一条直线和倒数第二条直线的交点的横坐标比较,比如这里黑色和绿色的交点为
H
H
H如果G的横坐标小于等于H的横坐标,那么就把最后一条直线删除,即把黑线从
A
−
>
G
A->G
A−>G的部分删除,可以看出加入紫色后构成下凸包的区域应该为C->G->H->F
然后对于每一个 x ( i ) x(i) x(i) ,二分得到所在线段,即可得到答案,复杂度为 O ( n log n ) O(n \log n) O(nlogn)。或者 x ( i ) x(i) x(i) 按从小到大顺序,按序判断,移动指针,如果 x ( i ) x(i) x(i) 的次序可以 O ( n ) O(n) O(n) 得到,则复杂度为 O ( n ) O(n) O(n) 。
Codeforces Round #816 (Div. 2) E
#include<bits/stdc++.h>
#define int long long
#define io ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
using namespace std;
const int maxn=1e6+5;
const int inf=0x3f3f3f3f;
const int mod=1e9+7;
struct node{
int dis,pos;
bool operator<(const node&x)const{
return x.dis<dis;
}
};
struct Line { //存直线
int k, b; //y=kx+b
Line() {}
Line(int k, int b) : k(k), b(b) {}
//求两直线交点横坐标
double intersect(Line l) {
//交点
double db = l.b - b;
double dk = k - l.k;
return db / dk;
}
int operator () (int x) {
return k * x + b;
}
};
struct ConvexHullTrick{ //凸壳优化
vector<double>points; //存交点
vector<Line>lines; //存直线
//每个直线范围是上一个直线的points[i-1]到自己的points[i]的范围
int size(){
return points.size();
}
void reset(){
points.clear();
lines.clear();
}
void init(Line l){
points.push_back(-inf);
}
void addLine(Line l){
if(points.size()==0){
points.push_back(-inf);
lines.push_back(l);
return;
}
/*
加入的直线和倒数第二个直线的交点,小于最后一个直线和倒数第二个直线的交点
那么最后一条直线就可以直接丢掉了
*/
while(lines.size()>=2&&
l.intersect(lines[lines.size() - 2]) <= points.back()){
points.pop_back();
lines.pop_back();
}
points.push_back(l.intersect(lines.back()));
lines.push_back(l);
}
int query(int x,int id){
return lines[id](x);
}
int query(int x){
int id=upper_bound(points.begin(),points.end(),x)-points.begin()-1;
return lines[id](x);
}
};
vector<pair<int,int>>vec[maxn];
int n,m,k;
int dist[maxn];
int vis[maxn];
void dijkstra(){
memset(vis,0,sizeof(vis));
priority_queue<node>q;
for(int i=1;i<=n;i++){
q.push({dist[i],i});
}
while(!q.empty()){
auto[distance,u]=q.top();
q.pop();
if(vis[u]) continue;
vis[u]=1;
for(auto[v,w]:vec[u]){
if(dist[v]>distance+w){
dist[v]=distance+w;
if(!vis[v]){
q.push({dist[v],v});
}
}
}
}
}
void solve(){
cin>>n>>m>>k;
for(int i=1;i<=m;i++){
int u,v,w;
cin>>u>>v>>w;
vec[u].push_back({v,w});
vec[v].push_back({u,w});
}
memset(dist,inf,sizeof(dist));
dist[1]=0;
dijkstra();
ConvexHullTrick cht;
while(k--){
cht.reset();
for(int u=1;u<=n;u++){
cht.addLine({-2*u,dist[u]+u*u});
/*
k=-2u b=dist[u]+u*u
y=kx+b
加入新的边,要求的是越来越斜,因为这里的u是单调递增的所以
斜率-2u就是越来越斜的
*/
}
for(int v=1;v<=n;v++){
dist[v]=cht.query(v)+v*v;
}
dijkstra();
}
for(int i=1;i<=n;i++){
cout<<dist[i]<<" ";
}
}
signed main(){
int t=1;
//cin>>t;
while(t--){
solve();
}
}
在斜率k(j)单调递增,并且
x
(
i
)
x(i)
x(i) 单调递增的时候我们可以优化掉一个
l
o
g
log
log,因为不用对直线排序所以排序的
l
o
g
log
log省了,此外
x
(
i
)
x(i)
x(i)也是单调递增的所以我们找的线段一定是越来越往后的,这个可以移动指针来判断,符合这种情况的时候又叫做单调队列优化斜率DP,可以用单调队列的写法完成,但是为了统一处理更加方便,这里我们还是用凸包的方法解决,ConvexHullTrick里面新增一个queuequery函数,通过移动指针sta来判断是否到了
x
(
i
)
x(i)
x(i)所在的线段上,注意边界即可
任务安排2
#include<bits/stdc++.h>
#define int long long
#define io ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
using namespace std;
const int maxn=1e6+5;
const int inf=0x3f3f3f3f;
const int mod=1e9+7;
struct Line { //存直线
int k, b; //y=kx+b
Line() {}
Line(int k, int b) : k(k), b(b) {}
//求两直线交点横坐标
double intersect(Line l) {
//交点
double db = l.b - b;
double dk = k - l.k;
return db / dk;
}
int operator () (int x) {
return k * x + b;
}
};
struct ConvexHullTrick{ //凸壳优化
int sta=0;
vector<double>points; //存交点
vector<Line>lines; //存直线
//每个直线范围是上一个直线的points[i-1]到自己的points[i]的范围
int size(){
return points.size();
}
void reset(){
sta=0;
points.clear();
lines.clear();
}
void init(Line l){
points.push_back(-inf);
}
void addLine(Line l){
if(points.size()==0){
points.push_back(-inf);
lines.push_back(l);
return;
}
/*
加入的直线和倒数第二个直线的交点,小于最后一个直线和倒数第二个直线的交点
那么最后一条直线就可以直接丢掉了
*/
while(lines.size()>=2&&
l.intersect(lines[lines.size() - 2]) <= points.back()){
points.pop_back();
lines.pop_back();
}
points.push_back(l.intersect(lines.back()));
lines.push_back(l);
}
int query(int x,int id){
return lines[id](x);
}
int query(int x){
int id=upper_bound(points.begin(),points.end(),x)-points.begin()-1;
return lines[id](x);
}
int queuequery(int x){
while(sta<points.size()&&points[sta]<=x){
sta++;
}
sta--;
return query(x,sta);
}
};
int n,s;
int t[maxn],c[maxn];
int st[maxn],sc[maxn];
int dp[maxn];
void solve(){
cin>>n>>s;
for(int i=1;i<=n;i++){
cin>>st[i]>>sc[i];
st[i]+=st[i-1];
sc[i]+=sc[i-1];
}
ConvexHullTrick cht;
for(int i=1;i<=n;i++){
if(i==1){
dp[i]=st[i]*sc[i]+s*sc[n];
}
else{
dp[i]=st[i]*sc[i]+s*sc[n];
dp[i]=min(st[i]*sc[i]+s*sc[n]+cht.queuequery(st[i]),dp[i]);
}
cht.addLine({-sc[i],dp[i]-s*sc[i]});
}
cout<<dp[n]<<"\n";
}
signed main(){
int t=1;
//cin>>t;
while(t--){
solve();
}
}