Description
有一排n个格子,一开始两个棋子在a和b。一个长度为m的操作序列要求第i次要将一个棋子移动到第x[i]个格子。一次移动代价为两格子之间距离,最小化代价之和
n
≤
2
e
5
n\le2e5
n≤2e5
Solution
非常套路的dp,设f[i,j]表示一个棋子在x[i],另一个在j的最小答案,保证j<x[i]。
转移的话我们往后推看是哪个棋子移动到了x[i+1]就好了,这样直接做状态数O(n2)的,转移O(1)
然后就可以发现每一次操作等价于区间加一个数,单点修改一个数,求区间最小值。那么拆掉绝对值,一颗线段树维护f[i]-i,一棵线段树维护f[i]+i就可以了
Solution
#include <stdio.h>
#include <string.h>
#include <algorithm>
#define rep(i,st,ed) for (int i=st;i<=ed;++i)
#define ls (now<<1)
#define rs (now<<1|1)
typedef long long LL;
const LL INF=1e15;
const int N=200005;
struct SegTree {
LL min[N<<2],tag[N<<2];
void push_down(int now) {
if (!tag[now]) return ;
LL w=tag[now]; tag[now]=0;
tag[ls]+=w,tag[rs]+=w;
min[ls]+=w,min[rs]+=w;
}
void add(int now,int tl,int tr,int l,int r,LL v) {
if (r<l) return ;
if (tl>=l&&tr<=r) return (void) (min[now]+=v,tag[now]+=v);
int mid=(tl+tr)>>1;
push_down(now);
if (l<=mid) add(ls,tl,mid,l,r,v);
if (mid+1<=r) add(rs,mid+1,tr,l,r,v);
min[now]=std:: min(min[ls],min[rs]);
}
void set(int now,int tl,int tr,int x,LL v) {
if (tl==tr) return (void) (min[now]=v);
int mid=(tl+tr)>>1;
push_down(now);
if (x<=mid) set(ls,tl,mid,x,v);
else set(rs,mid+1,tr,x,v);
min[now]=std:: min(min[ls],min[rs]);
}
LL query(int now,int tl,int tr,int l,int r) {
if (r<l) return INF;
if (tl>=l&&tr<=r) return min[now];
int mid=(tl+tr)>>1; LL qx=INF,qy=INF;
push_down(now);
if (l<=mid) qx=query(ls,tl,mid,l,r);
if (mid+1<=r) qy=query(rs,mid+1,tr,l,r);
return std:: min(qx,qy);
}
void build(int now,int tl,int tr) {
min[now]=INF;
if (tl==tr) return ;//(void) (min[now]=(opt)?(-tl):(tl));
int mid=(tl+tr)>>1;
build(ls,tl,mid),build(rs,mid+1,tr);
}
void change(int now,int tl,int tr) {
if (tl==tr) return (void) (min[now]-=tl);
int mid=(tl+tr)>>1;
push_down(now);
change(ls,tl,mid),change(rs,mid+1,tr);
min[now]=std:: min(min[ls],min[rs]);
}
void debug(int now,int tl,int tr) {
if (tl==tr) return (void) (printf("%lld ", min[now]));
int mid=(tl+tr)>>1;
push_down(now);
debug(ls,tl,mid),debug(rs,mid+1,tr);
}
} T1,T2;
LL x[N];
int read() {
int x=0,v=1; char ch=getchar();
for (;ch<'0'||ch>'9';v=(ch=='-')?(-1):v,ch=getchar());
for (;ch<='9'&&ch>='0';x=x*10+ch-'0',ch=getchar());
return x*v;
}
int main(void) {
freopen("data.in","r",stdin);
freopen("myp.out","w",stdout);
LL n=read(),T=read(),a=read(),b=read();
rep(i,1,T) x[i]=read();
T1.build(1,1,n),T2.build(1,1,n);
T1.set(1,1,n,a,abs(b-x[1])+a),T1.set(1,1,n,b,abs(a-x[1])+b);
T2.set(1,1,n,a,abs(b-x[1])-a),T2.set(1,1,n,b,abs(a-x[1])-b);
// T1.debug(1,1,n),puts("");
rep(i,2,T) {
LL d=abs(x[i]-x[i-1]);
LL ra=T1.query(1,1,n,x[i],n)-x[i];
LL rb=x[i]+T2.query(1,1,n,1,x[i]);
LL rc=T1.query(1,1,n,x[i],x[i])-x[i]+d;
LL res=std:: min(ra,std:: min(rc,rb));
T1.tag[1]+=d,T2.tag[1]+=d;
T1.set(1,1,n,x[i-1],res+x[i-1]),T2.set(1,1,n,x[i-1],res-x[i-1]);
// T1.debug(1,1,n),puts("");
} T1.change(1,1,n);
printf("%lld", T1.min[1]);
return 0;
}