题目
题目概要
有
n
n
n 个点,你要把它们分成两个子序列,使得两个子序列中相邻的点的曼哈顿距离之和最小。输出这个值。
数据范围与提示
多组数据,
T
≤
1
0
3
,
n
≤
1
0
5
T\le 10^3,\;n\le 10^5
T≤103,n≤105 并保证
∑
n
≤
3
×
1
0
5
\sum n\le 3\times 10^5
∑n≤3×105 。
思路
容易想到一个 n 2 n^2 n2 的 d p \tt dp dp,用 f ( i , j ) f(i,j) f(i,j) 表示两个子序列的末尾节点分别是 i , j i,j i,j,前 max ( i , j ) \max(i,j) max(i,j) 个元素已经全部划分完了。转移比较简单,考虑最后一个元素划分给了谁就行。
状态数高达 n 2 n^2 n2,这可不是个好兆头。考虑 一次性多转移几步。转移的本质是考虑 i , j i,j i,j 的大小关系变化。所以完全可以 f ( i + 1 , i ) f(i+1,i) f(i+1,i) 考虑 i + 1 i+1 i+1 从 i + 2 i+2 i+2 一直连接到 j ( j ≥ i + 1 ) j(j\ge i+1) j(j≥i+1),然后 i i i 连接了 j + 1 j+1 j+1 转移到了 f ( j , j + 1 ) f(j,j+1) f(j,j+1) 。
如果记
g
(
i
)
=
f
(
i
+
1
,
i
)
g(i)=f(i+1,i)
g(i)=f(i+1,i),再记
d
i
s
(
i
,
j
)
dis(i,j)
dis(i,j) 为第
i
i
i 个点和第
j
j
j 个点的曼哈顿距离,那么有转移式
g
(
i
)
=
min
j
=
0
i
−
1
[
g
(
j
)
+
d
i
s
(
j
,
i
+
1
)
+
∑
x
=
j
+
1
i
−
1
d
i
s
(
x
,
x
+
1
)
]
g(i)=\min_{j=0}^{i-1}\left[g(j)+dis(j,i+1)+\sum_{x=j+1}^{i-1}dis(x,x+1)\right]
g(i)=j=0mini−1[g(j)+dis(j,i+1)+x=j+1∑i−1dis(x,x+1)]
令
d
i
s
(
0
,
i
)
=
0
dis(0,i)=0
dis(0,i)=0 表示子序列的第一个点没有距离产生。然后你发现
∑
x
=
j
+
1
i
−
1
d
i
s
(
x
,
x
+
1
)
\sum_{x=j+1}^{i-1}dis(x,x+1)
∑x=j+1i−1dis(x,x+1) 也是一个只跟
j
,
i
j,i
j,i 有关系的值——事实上就是前缀和。用
s
(
i
)
=
∑
x
=
0
i
d
i
s
(
x
,
x
+
1
)
s(i)=\sum_{x=0}^{i}dis(x,x+1)
s(i)=∑x=0idis(x,x+1) 来化简,可知
g
(
i
)
=
min
j
=
0
i
−
1
[
g
(
j
)
+
d
i
s
(
j
,
i
+
1
)
−
s
(
j
)
]
+
s
(
i
−
1
)
g(i)=\min_{j=0}^{i-1}\big[g(j)+dis(j,i+1)-s(j)\big]+s(i-1)
g(i)=j=0mini−1[g(j)+dis(j,i+1)−s(j)]+s(i−1)
唯一难搞的就是这个 d i s ( j , i + 1 ) dis(j,i+1) dis(j,i+1) 。考虑去掉绝对值。分四类情况讨论即可。然后 g ( i ) g(i) g(i) 转移式只跟 i , j i,j i,j 有关,相当于一个矩形(即四类偏序情况)求 min \min min 的问题。排序后树状数组即可。
g g g 从自己转移而来,像极了分治 F F T \tt FFT FFT,我们就叫它 c d q cdq cdq 分治 吧。
时间复杂度 O ( n log 2 n ) \mathcal O(n\log^2n) O(nlog2n) 。
代码
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
typedef long long int_;
inline int readint(){
int a = 0; char c = getchar(), f = 1;
for(; c<'0'||c>'9'; c=getchar())
if(c == '-') f = -f;
for(; '0'<=c&&c<='9'; c=getchar())
a = (a<<3)+(a<<1)+(c^48);
return a*f;
}
const int_ infty = (1ll<<60)-1;
const int MaxN = 100005;
struct BIT{
int_ c[MaxN];
int_ query(int x){
int_ res = infty;
for(int i=x; i; i-=(i&-i))
res = min(res,c[i]);
return res;
}
void modify(int_ v,int x,int n){
for(int i=x; i<=n; i+=(i&-i))
c[i] = min(c[i],v);
}
void clear(int n){
for(int i=1; i<=n; ++i)
c[i] = infty;
}
};
struct Point{
int x, y; int haxi, id;
void input(){
x = readint(), y = readint();
}
int disTo(const Point &t){
return abs(x-t.x)+abs(y-t.y);
}
bool operator < (const Point &t) const {
return id < t.id;
}
static bool cmpX(const Point &a,const Point &b){
return a.x < b.x;
}
static bool cmpY(const Point &a,const Point &b){
return a.y < b.y;
}
};
Point p[MaxN];
int dis(int i,int j){
return p[i].disTo(p[j]);
}
BIT pre, suf; // relationship of y
int_ g[MaxN], s[MaxN];
void solve(int l,int r){
if(l == r) return ;
int mid = (l+r)>>1;
solve(l,mid); // left part first
sort(p+l,p+r+2,Point::cmpY);
p[l].haxi = 1; // hash it
for(int i=l+1; i<=r+1; ++i)
if(p[i].y != p[i-1].y)
p[i].haxi = p[i-1].haxi+1;
else p[i].haxi = p[i-1].haxi;
sort(p+l,p+r+2,Point::cmpX);
int o = r-l+2; // length
pre.clear(o), suf.clear(o);
for(int i=l; i<=r+1; ++i){
if(p[i].id == mid+1) continue;
if(p[i].id <= mid){ // left
//printf("%d: (%d, %d) as left\n",p[i].id,p[i].x,p[i].y);
//printf("pre insert %d: %lld\n",p[i].haxi,g[p[i].id]-s[p[i].id]-p[i].x-p[i].y);
pre.modify(
g[p[i].id]-s[p[i].id]
-p[i].x-p[i].y,
p[i].haxi, o
);
suf.modify(
g[p[i].id]-s[p[i].id]
-p[i].x+p[i].y,
o+1-p[i].haxi, o
);
}
else{
g[p[i].id-1] = min(
g[p[i].id-1],
p[i].x+p[i].y
+pre.query(p[i].haxi)
+s[p[i].id-2]
);
//printf("update %d: %lld\n",p[i].id-1,p[i].x+p[i].y+pre.query(p[i].haxi)+s[p[i].id-2]);
g[p[i].id-1] = min(
g[p[i].id-1],
p[i].x-p[i].y
+suf.query(o-p[i].haxi)
+s[p[i].id-2]
);
}
}
pre.clear(o), suf.clear(o);
for(int i=r+1; i>=l; --i){
if(p[i].id == mid+1) continue;
if(p[i].id <= mid){ // left
pre.modify(
g[p[i].id]-s[p[i].id]
+p[i].x-p[i].y,
p[i].haxi, o
);
suf.modify(
g[p[i].id]-s[p[i].id]
+p[i].x+p[i].y,
o+1-p[i].haxi, o
);
}
else{
g[p[i].id-1] = min(
g[p[i].id-1],
-p[i].x+p[i].y
+pre.query(p[i].haxi)
+s[p[i].id-2]
);
g[p[i].id-1] = min(
g[p[i].id-1],
-p[i].x-p[i].y
+suf.query(o-p[i].haxi)
+s[p[i].id-2]
);
}
}
sort(p+l,p+r+2);
solve(mid+1,r); // right part
}
int main(){
for(int T=readint(); T; --T){
int n = readint();
for(int i=1; i<=n; ++i)
p[i].input(), p[i].id = i;
if(n <= 2){
puts("0"); continue;
}
for(int i=1; i<n; ++i){
s[i] = s[i-1]+dis(i,i+1);
g[i] = s[i-1]; // from zero
//printf("s[%d] = %lld\n",i,s[i]);
}
g[n] = s[n-1]; // from zero
solve(1,n-1); // ans = g(n)
int_ ans = g[n];
for(int i=1; i<n; ++i){
//printf("g[%d] = %lld\n",i,g[i]);
ans = min(ans,g[i]-s[i]+s[n-1]);
}
printf("%lld\n",ans);
}
return 0;
}