题解
对于在河同侧的,直接累加进答案。
剩下的直接看做 m m 条线段(因为桥宽度为 1 1 ,最后加上就行).
然后当 k=1 k = 1 时:
就是找到一个 p p 来最小化
直接找这 2m 2 m 个点中位数即可.
然后当 k=2 k = 2 时:
每个人应该走离线段中点最近的那个桥。
于是先把线段按照线段的中点排序(也就是按 li+ri l i + r i 排序)。
从左往右枚举一个扫描线,它左边的走桥 1 1 ,右边的走桥,就转化成左右都是 k=1 k = 1 了,看看扫描线在哪的时候答案最小.
这时候左右分别排序显然时间复杂度无法承受,因此使用权值线段树,把点离散化,每次从右边线段树删两点,加到左边线段树来,然后统计一下左右的代价和是否小于 ans a n s
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <vector>
#define px first
#define py second
typedef long long LL;
typedef std :: pair<int, int> P;
const int N = 1e5 + 20;
int k, n, cnt;
LL ans;
P a[N];
void solve1() {
static std :: vector<int> vec;
vec.clear();
for(int i = 1; i <= cnt; i ++) {
vec.push_back(a[i].px);
vec.push_back(a[i].py);
}
std :: sort(vec.begin(), vec.end());
for(int i = 1; i <= cnt; i ++) ans -= vec[i - 1];
for(int i = cnt + 1; i <= cnt << 1; i ++) ans += vec[i - 1];
printf("%lld\n", ans);
}
bool cmp(const P & a, const P & b) {
return a.px + a.py < b.px + b.py;
}
LL pos[N << 4], len;
struct Seg {
int sz[N << 4];
LL sum[N << 4];
Seg() {
memset(sz, 0, sizeof sz);
memset(sum, 0, sizeof sum);
}
int kth(int k, int l, int r, int rk) {
if(l == r) return l;
int mid = l + r >> 1;
if(rk <= sz[k << 1]) return kth(k << 1, l, mid, rk);
return kth(k << 1 | 1, mid + 1, r, rk - sz[k << 1]);
}
void modify(int k, int l, int r, int p, int val) {
if(l == r) {
sum[k] += 1ll * pos[l] * val;
sz[k] += val;
return ;
}
int mid = l + r >> 1;
if(p <= mid) modify(k << 1, l, mid, p, val);
else modify(k << 1 | 1, mid + 1, r, p, val);
sz[k] = sz[k << 1] + sz[k << 1 | 1];
sum[k] = sum[k << 1] + sum[k << 1 | 1];
}
LL querysz(int k, int l, int r, int L, int R) {
if(l > R || r < L) return 0;
if(L <= l && r <= R) return sz[k];
int mid = l + r >> 1;
return querysz(k << 1, l, mid, L, R) \
+ querysz(k << 1 | 1, mid + 1, r, L, R);
}
LL querysum(int k, int l, int r, int L, int R) {
if(l > R || r < L) return 0LL;
if(L <= l && r <= R) return sum[k];
int mid = l + r >> 1;
return querysum(k << 1, l, mid, L, R) \
+ querysum(k << 1 | 1, mid + 1, r, L, R);
}
} le, ri;
void solve2() {
std :: sort(a + 1, a + cnt + 1, cmp);
std :: sort(pos + 1, pos + len + 1);
len = std :: unique(pos + 1, pos + len + 1) - (pos + 1);
for(int i = 1; i <= cnt; i ++) {
a[i].px = std :: lower_bound(pos + 1, pos + len + 1, a[i].px) - pos;
a[i].py = std :: lower_bound(pos + 1, pos + len + 1, a[i].py) - pos;
ri.modify(1, 1, len, a[i].px, 1);
ri.modify(1, 1, len, a[i].py, 1);
}
LL ans2 = 1e18;
for(int i = 1; i <= cnt; i ++) {
le.modify(1, 1, len, a[i].px, 1);
le.modify(1, 1, len, a[i].py, 1);
ri.modify(1, 1, len, a[i].px, -1);
ri.modify(1, 1, len, a[i].py, -1);
int mid1 = le.kth(1, 1, len, i);
int mid2 = ri.kth(1, 1, len, cnt - i);
LL cl = le.querysz(1, 1, len, 1, mid1) * pos[mid1] - le.querysum(1, 1, len, 1, mid1) + \
le.querysum(1, 1, len, mid1 + 1, len) - le.querysz(1, 1, len, mid1 + 1, len) * pos[mid1];
LL cr = ri.querysz(1, 1, len, 1, mid2) * pos[mid2] - ri.querysum(1, 1, len, 1, mid2) + \
ri.querysum(1, 1, len, mid2 + 1, len) - ri.querysz(1, 1, len, mid2 + 1, len) * pos[mid2];
if(cl + cr < ans2) ans2 = cl + cr;
}
printf("%lld\n", ans + ans2);
}
int main() {
char s1[5], s2[5];
scanf("%d%d", &k, &n);
for(int i = 1, u, v; i <= n; i ++) {
scanf("%s%d%s%d", s1, &u, s2, &v);
if(* s1 == * s2) {
ans += std :: abs(u - v);
continue ;
}
if(u > v) std :: swap(u, v);
++ ans;
pos[++ len] = u;
pos[++ len] = v;
a[++ cnt] = P(u, v);
}
if(!cnt) return printf("%lld\n", ans), 0;
k == 1 ? solve1() : solve2();
return 0;
}