线段树好题（学到了线段树的新用法：用线段树维护区间内的最短路）
线段树的每个节点对应一段列区间 [l, r]，并维护四个值：在该区间内从左端上格到右端上格、左端上格到右端下格、左端下格到右端上格、左端下格到右端下格的最短路长度；不可达时记为 inf。
// Summary of one segment of the 2-row grid covering columns [l, r]:
// shortest path lengths between the end cells of the segment
// (inf when the pair is unreachable).
struct node {
    int dis_1;  // top cell of column l    -> top cell of column r
    int dis_2;  // top cell of column l    -> bottom cell of column r
    int dis_3;  // bottom cell of column l -> top cell of column r
    int dis_4;  // bottom cell of column l -> bottom cell of column r
};
node tree[N << 2];  // segment tree storage (4 * N nodes)
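pushup：合并相邻的左右两个子区间时，从左区间最右一列走到右区间最左一列需要横向走 1 步（可以在上行或下行穿过交界），因此每个值取两种穿越方式中的较小者再加 1，并与 inf 取 min，保证不可达的状态仍然恰好是 inf。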
void pushup(int u) {
tree[u].dis_1 = min(inf, min(tree[u << 1].dis_1 + tree[u << 1 | 1].dis_1, tree[u << 1].dis_2 + tree[u << 1 | 1].dis_3) + 1);
tree[u].dis_2 = min(inf, min(tree[u << 1].dis_2 + tree[u << 1 | 1].dis_4, tree[u << 1].dis_1 + tree[u << 1 | 1].dis_2) + 1);
tree[u].dis_3 = min(inf, min(tree[u << 1].dis_3 + tree[u << 1 | 1].dis_1, tree[u << 1].dis_4 + tree[u << 1 | 1].dis_3) + 1);
tree[u].dis_4 = min(inf, min(tree[u << 1].dis_4 + tree[u << 1 | 1].dis_4, tree[u << 1].dis_3 + tree[u << 1 | 1].dis_2) + 1);
}
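完整代码（查询时只需取两个格子所在列之间的区间：在只有两行的网格中，若最短路走出这段列区间再绕回来，要么是可以删掉的回路，要么说明交界列上下两格都可通行，直接竖着走一步更短，因此最短路不会走出这两列之外）：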
#include<bits/stdc++.h>
using namespace std;
using ll = long long;  // currently unused in this program

// Array capacity: must exceed the number of columns n (columns are 1-indexed).
const int N = 200010;
// Sentinel distance meaning "unreachable" (1e9 + 7).
const int inf = 1000000007;
// Merging two segment summaries adds two capped distances plus one step;
// that sum must not overflow int or the inf-capping breaks.
static_assert(2LL * inf + 1 <= INT_MAX, "inf + inf + 1 must fit in int");

int n, q;       // number of columns, number of queries
char c[3][N];   // c[row][col]; rows 1..2 and columns 1..n are used, '.' = free
// Summary of one segment of the 2-row grid covering columns [l, r]:
// shortest path lengths between the end cells of the segment
// (inf when the pair is unreachable).
struct node {
    int dis_1;  // top cell of column l    -> top cell of column r
    int dis_2;  // top cell of column l    -> bottom cell of column r
    int dis_3;  // bottom cell of column l -> top cell of column r
    int dis_4;  // bottom cell of column l -> bottom cell of column r
};
node tree[N << 2];  // segment tree storage (4 * N nodes)
void pushup(int u) {
tree[u].dis_1 = min(inf, min(tree[u << 1].dis_1 + tree[u << 1 | 1].dis_1, tree[u << 1].dis_2 + tree[u << 1 | 1].dis_3) + 1);
tree[u].dis_2 = min(inf, min(tree[u << 1].dis_2 + tree[u << 1 | 1].dis_4, tree[u << 1].dis_1 + tree[u << 1 | 1].dis_2) + 1);
tree[u].dis_3 = min(inf, min(tree[u << 1].dis_3 + tree[u << 1 | 1].dis_1, tree[u << 1].dis_4 + tree[u << 1 | 1].dis_3) + 1);
tree[u].dis_4 = min(inf, min(tree[u << 1].dis_4 + tree[u << 1 | 1].dis_4, tree[u << 1].dis_3 + tree[u << 1 | 1].dis_2) + 1);
}
void build(int u, int x, int y) {
if(x == y) {
tree[u].dis_1 = tree[u].dis_2 = tree[u].dis_3 = tree[u].dis_4 = inf;
if(c[1][x] == '.') tree[u].dis_1 = 0;
if(c[2][x] == '.') tree[u].dis_4 = 0;
if(c[1][x] == '.' && c[2][x] == '.') tree[u].dis_2 = tree[u].dis_3 = 1;
return ;
}
int mid = (x + y) >> 1;
build(u << 1, x, mid), build(u << 1 | 1, mid + 1, y);
pushup(u);
}
node query(int u, int x, int y, int a, int b) {
if(a <= x && b >= y) return tree[u];
int mid = (x + y) >> 1;
if(b <= mid) return query(u << 1, x, mid, a, b);
if(a > mid) return query(u << 1 | 1, mid + 1, y, a, b);
node t, t_l = query(u << 1, x, mid, a, b), t_r = query(u << 1 | 1, mid + 1, y, a, b);
t.dis_1 = min(inf, min(t_l.dis_1 + t_r.dis_1, t_l.dis_2 + t_r.dis_3) + 1);
t.dis_2 = min(inf, min(t_l.dis_2 + t_r.dis_4, t_l.dis_1 + t_r.dis_2) + 1);
t.dis_3 = min(inf, min(t_l.dis_3 + t_r.dis_1, t_l.dis_4 + t_r.dis_3) + 1);
t.dis_4 = min(inf, min(t_l.dis_4 + t_r.dis_4, t_l.dis_3 + t_r.dis_2) + 1);
return t;
}
// Read the 2 x n maze and answer q shortest-path queries between cells.
// Cells are numbered 1..n along the top row and n+1..2n along the bottom row.
// Each answer is the shortest path length, or -1 if the cells are disconnected.
int main () {
    // No memset of tree needed: build() assigns every node query() can reach.
    scanf("%d %d", &n, &q);
    // " %c" skips whitespace exactly like `cin >> char`, but keeps all input
    // on C stdio instead of mixing it with iostreams.
    for (int i = 1; i <= 2; i++)
        for (int j = 1; j <= n; j++) scanf(" %c", &c[i][j]);
    build(1, 1, n);
    while (q--) {
        int u, v;
        scanf("%d %d", &u, &v);
        // Cell id -> (row, column); row is 1 (top) or 2 (bottom).
        int x_1 = u / (n + 1) + 1, y_1 = (u - 1) % n + 1;
        int x_2 = v / (n + 1) + 1, y_2 = (v - 1) % n + 1;
        // Paths are undirected, so orient every query left to right.
        if (y_1 > y_2) swap(x_1, x_2), swap(y_1, y_2);
        node ans = query(1, 1, n, y_1, y_2);
        // Pick the stored distance matching (start row, end row).
        int dist;
        if (x_1 == 1) dist = (x_2 == 1) ? ans.dis_1 : ans.dis_2;
        else          dist = (x_2 == 1) ? ans.dis_3 : ans.dis_4;
        printf("%d\n", dist == inf ? -1 : dist);
    }
    return 0;
}