给定一个 n × m(1 <= m <= 3) 的点网格,网格的边上以及点上都有权值。
初始时所有点的权值都为 0 。
维护两种操作:
1. x1 y1 x2 y2 c 把从 (x1, y1) 到 (x2, y2) 的最短路上的所有节点的权值都增加 c 。保证最短路唯一。
2. x y 询问 (x, y) 的权值。
解:
为了方便,所有编号均从 0 开始。
对于操作1,我们不妨设 x1 <= x2 。
我们先考虑 (x1, y1) 和 (x2, y2) 之间的最短路可以由哪些简单的部分构成:
1. 允许经过整个网格,(x1, y1) 经过最短路走到 (x1, p)。
2. 仅允许经过 [x1, x2] × [0, m-1] ,(x1, p) 经过最短路走到 (x2, q)。
3. 允许经过整个网格,(x2, q) 经过最短路走到 (x2, y2)。
其中 0 <= p, q < m 。
允许经过整个网络,(x, y) 走到 (x, y') 的最短路由以下几个部分构成:
1. 仅允许经过 [0, x] × [0, m-1] ,(x, y) 经过最短路走到 (x, p)。
2. 仅允许经过 [x, n-1] × [0, m-1] ,(x, p) 经过最短路走到 (x, y')。
或者
1'. 仅允许经过 [x, n-1] × [0, m-1] ,(x, y) 经过最短路走到 (x, p)。
2'. 仅允许经过 [0, x] × [0, m-1] ,(x, p) 经过最短路走到 (x, y')。
其中 0 <= p < m 。
综合以上,我们可以把 (x1, y1) 和 (x2, y2) 之间的最短路 用至多 5 个如下特殊的最短路连接起来:
仅允许经过 [L, R] × [0, m-1] ,(p) 走到 (q) 的最短路。
其中,对于确定的区间[L, R],我们把 点(L, p) 简记为 (p),把 点(R, q) 简记为 (q+m)。
其中,0 <= p, q < 2*m。
这样的最短路可以考虑用线段树去维护。
考虑区间[L, R]所维护的信息:
仅经过 [L, R] × [0, m-1] 之内的点,(p) 走到 (q) 的最短路长度 dis[p][q]。
假设两个区间分别是 [L, mid] 和 [mid+1, R] ,如何合并两个区间的答案见如下代码:
1 void get(int k, int L, int R, int x, int y, node *res) 2 { 3 if (L == x && R == y) 4 { 5 memcpy(&res[k], &tree[k], sizeof(node)); 6 return; 7 } 8 int mid = (L+R)/2; 9 if (y <= mid) 10 { 11 get(k<<1, L, mid, x, y, res); 12 memcpy(&res[k], &res[k<<1], sizeof(node)); 13 } 14 else if (x > mid) 15 { 16 get(k<<1|1, mid+1, R, x, y, res); 17 memcpy(&res[k], &res[k<<1|1], sizeof(node)); 18 } 19 else 20 { 21 get(k<<1, L, mid, x, mid, res); 22 get(k<<1|1, mid+1, R, mid+1, y, res); 23 for (int i = 0; i < m; ++ i) 24 for (int j = i+1; j < m; ++ j) 25 { 26 res[k].dis[i][j] = res[k<<1].dis[i][j]; 27 res[k].dis[i+m][j+m] = res[k<<1|1].dis[i+m][j+m]; 28 for (int p = 0; p < m; ++ p) 29 for (int q = 0; q < m; ++ q) if (p != q) 30 { 31 res[k].dis[i][j] = min(res[k].dis[i][j], res[k<<1].dis[i][p+m]+res[k<<1|1].dis[p][q]+res[k<<1].dis[j][q+m]+rc[mid][p]+rc[mid][q]); 32 res[k].dis[i+m][j+m] = min(res[k].dis[i+m][j+m], res[k<<1|1].dis[p][i+m]+res[k<<1].dis[p+m][q+m]+res[k<<1|1].dis[q][j+m]+rc[mid][p]+rc[mid][q]); 33 } 34 res[k].dis[j][i] = res[k].dis[i][j]; 35 res[k].dis[j+m][i+m] = res[k].dis[i+m][j+m]; 36 } 37 for (int i = 0; i < m; ++ i) 38 for (int j = 0; j < m; ++ j) 39 { 40 res[k].dis[i][j+m] = INF; 41 for (int p = 0; p < m; ++ p) 42 res[k].dis[i][j+m] = min(res[k].dis[i][j+m], res[k<<1].dis[i][p+m]+res[k<<1|1].dis[p][j+m]+rc[mid][p]); 43 if (m == 3) 44 { 45 for (int p = 0; p < m; ++ p) 46 for (int q = 0; q < m; ++ q) if (p != q) 47 res[k].dis[i][j+m] = min(res[k].dis[i][j+m], res[k<<1].dis[i][p+m]+res[k<<1|1].dis[p][q]+res[k<<1].dis[q+m][(3-p-q)+m]+res[k<<1|1].dis[(3-p-q)][j+m]+rc[mid][0]+rc[mid][1]+rc[mid][2]); 48 } 49 res[k].dis[j+m][i] = res[k].dis[i][j+m]; 50 } 51 } 52 }
当我们得到了 (x1, y1) 和 (x2, y2) 之间的最短路长度之后,剩下的部分就是找到最短路具体是哪一条路。
注意到,任何一段最短路都能被分割成 O(log n) 段单个区间内两端点 (p) 和 (q) 的最短路。
寻找具体最短路的代码:
1 void go_path(int k, int L, int R, int x, int y, node *res, int p, int q, ull c) 2 { 3 if (L == x && R == y) 4 { 5 Tree[depth[k]][q].add(p < m ? place(L, p) : place(R, p-m), c); 6 return; 7 } 8 int mid = (L+R)/2; 9 if (y <= mid) 10 go_path(k<<1, L, mid, x, y, res, p, q, c); 11 else if (x > mid) 12 go_path(k<<1|1, mid+1, R, x, y, res, p, q, c); 13 else 14 { 15 if (p < m && q < m) 16 { 17 if (res[k].dis[p][q] == res[k<<1].dis[p][q]) 18 { 19 go_path(k<<1, L, mid, x, mid, res, p, q, c); 20 return; 21 } 22 for (int pp = 0; pp < m; ++ pp) 23 for (int qq = 0; qq < m; ++ qq) 24 if (res[k].dis[p][q] == res[k<<1].dis[p][pp+m]+res[k<<1|1].dis[pp][qq]+res[k<<1].dis[q][qq+m]+rc[mid][pp]+rc[mid][qq]) 25 { 26 go_path(k<<1, L, mid, x, mid, res, p, pp+m, c); 27 go_path(k<<1|1, mid+1, R, mid+1, y, res, pp, qq, c); 28 go_path(k<<1, L, mid, x, mid, res, q, qq+m, c); 29 return; 30 } 31 } 32 else if (p >= m && q >= m) 33 { 34 if (res[k].dis[p][q] == res[k<<1|1].dis[p][q]) 35 { 36 go_path(k<<1|1, mid+1, R, mid+1, y, res, p, q, c); 37 return; 38 } 39 for (int pp = 0; pp < m; ++ pp) 40 for (int qq = 0; qq < m; ++ qq) 41 if (res[k].dis[p][q] == res[k<<1|1].dis[pp][p]+res[k<<1].dis[pp+m][qq+m]+res[k<<1|1].dis[qq][q]+rc[mid][pp]+rc[mid][qq]) 42 { 43 go_path(k<<1|1, mid+1, R, mid+1, y, res, pp, p, c); 44 go_path(k<<1, L, mid, x, mid, res, pp+m, qq+m, c); 45 go_path(k<<1|1, mid+1, R, mid+1, y, res, qq, q, c); 46 return; 47 } 48 } 49 else 50 { 51 if (p > q) swap(p, q); 52 for (int r = 0; r < m; ++ r) 53 if (res[k].dis[p][q] == res[k<<1].dis[p][r+m]+res[k<<1|1].dis[r][q]+rc[mid][r]) 54 { 55 go_path(k<<1, L, mid, x, mid, res, p, r+m, c); 56 go_path(k<<1|1, mid+1, R, mid+1, y, res, r, q, c); 57 return; 58 } 59 if (m == 3) 60 { 61 for (int pp = 0; pp < m; ++ pp) 62 for (int qq = 0; qq < m; ++ qq) if (pp != qq) 63 if (res[k].dis[p][q] == res[k<<1].dis[p][pp+m]+res[k<<1|1].dis[pp][qq]+res[k<<1].dis[qq+m][(3-pp-qq)+m]+res[k<<1|1].dis[(3-pp-qq)][q]+rc[mid][0]+rc[mid][1]+rc[mid][2]) 64 { 65 go_path(k<<1, L, mid, x, mid, res, p, pp+m, c); 66 go_path(k<<1|1, mid+1, R, mid+1, y, res, pp, qq, c); 67 go_path(k<<1, L, mid, x, mid, res, qq+m, (3-pp-qq)+m, c); 68 go_path(k<<1|1, mid+1, R, mid+1, y, res, 3-pp-qq, q, c); 69 return; 70 } 71 } 72 } 73 } 74 }
我们对每个区间 [L, R] 可以先预处理出 以 (p) 为源的最短路树(如果不唯一则取任意一个,由于题目保证了最短路唯一,所以那些不唯一的部分一定不会出现在最短路中),其中0 <= p < 2*m。
而对于每一段单个区间的 (p) 和 (q) 的最短路,在这一段路中的每个点的权值都增加 c ,可以看做是区间[L, R]内,【以 (p) 为根的最短路树中,(q)到根之间所有节点增加 c】。
此时我们考虑询问,(x, y)的权值可以表示成若干个(至多O(log n)个)区间的增量之和,即若干个形如 【在某棵树中某节点的权值】 之和。
于是,我们只需要维护两个树上的操作:
1. 根到某个节点之间的一条链整体增加一个数;
2. 询问某个节点的值。
预处理出每个区间的每一棵最短路树的DFS序后,接下来的问题就变成了:维护一个数列,
1'. 区间增加一个数;
2'. 单点询问。
这个问题可以用线段树或者树状数组解决。
综合以上,整个算法的时间复杂度为 $O(n m^2 \log^2 n + q (m^4 \log n + m^2 \log^2 n) )$ 。