【题目链接】
【思路要点】
- 首先,我们显然需要令所有边为 A A A 进行一次询问,得到 s s s 到 t t t 的最短路。
- 我们可以从找到最短路上的一个点出发:
- 令所有与编号在 [ 1 , m i d ] [1,mid] [1,mid] 中的点相邻的边为 B B B ,其余边为 A A A ,通过判断最短路是否不变,我们可以知道 s s s 到 t t t 间是否存在一条不经过编号在 [ 1 , m i d ] [1,mid] [1,mid] 中的点的最短路。因此通过二分,我们可以在 L o g 2 N Log_2N Log2N 次操作内找到最短路上的一个点 r o o t root root 。
- 以 r o o t root root 为根,进行 B F S BFS BFS ,显然 s s s 和 t t t 中必然有一个是其间任何一条最短路上 B F S BFS BFS 序最大的点。令所有与 B F S BFS BFS 序在 [ m i d + 1 , N ] [mid+1,N] [mid+1,N] 中的点相邻的边为 B B B ,其余边为 A A A ,通过判断最短路是否不变,我们可以知道 s s s 和 t t t 中是否有至少一个的 B F S BFS BFS 序在 [ m i d + 1 , N ] [mid+1,N] [mid+1,N] 中。因此通过二分,我们可以在 L o g 2 N Log_2N Log2N 次操作内找到 s s s 和 t t t 中的一个。
- 不妨令找到了 s s s ,以 s s s 为根,进行 B F S BFS BFS ,同样通过二分,我们可以找到 t t t 。
- 时间复杂度 O ( M L o g N ) O(MLogN) O(MLogN) ,使用操作次数不超过 3 L o g 2 N + 1 3Log_2N+1 3Log2N+1 ,实测最多使用操作 51 51 51 次,得分 90 90 90 。
- 我们也可以从找到最短路上的一条边出发:
- 令所有与编号在 [ 1 , m i d ] [1,mid] [1,mid] 中的边为 B B B ,其余边为 A A A ,通过判断最短路是否不变,我们可以知道 s s s 到 t t t 间是否存在一条不经过编号在 [ 1 , m i d ] [1,mid] [1,mid] 中的边的最短路。因此通过二分,我们可以在 L o g 2 N Log_2N Log2N 次操作内找到最短路上的一条边 ( r o o t x , r o o t y ) (rootx,rooty) (rootx,rooty) 。
- 由于 ( r o o t x , r o o t y ) (rootx,rooty) (rootx,rooty) 在最短路上,有 d i s t ( s , r o o t x ) ≠ d i s t ( s , r o o t y ) dist(s,rootx)\ne dist(s,rooty) dist(s,rootx)̸=dist(s,rooty) 且 d i s t ( t , r o o t x ) ≠ d i s t ( t , r o o t y ) dist(t,rootx)\ne dist(t,rooty) dist(t,rootx)̸=dist(t,rooty) ,因此,以 r o o t x , r o o t y rootx,rooty rootx,rooty 为根,分别进行 B F S BFS BFS ,我们可以将点集分成两组,一组离 r o o t x rootx rootx 更近,另一组离 r o o t y rooty rooty 更近, s s s 和 t t t 一定分别在其中一组中,用上文中的二分解决即可。
- 时间复杂度 O ( M L o g N ) O(MLogN) O(MLogN) ,使用操作次数不超过 L o g 2 M + L o g 2 S + L o g 2 T + 1 ( S + T ≤ N ) Log_2M+Log_2S+Log_2T+1\ (S+T≤N) Log2M+Log2S+Log2T+1 (S+T≤N) ,实测最多使用操作 50 50 50 次,得分 100 100 100 。
【代码】
// 90 points Version - 3LogN + 1 #include "highway.h" #include<bits/stdc++.h> using namespace std; const int MAXN = 1e5 + 5; typedef long long ll; int n, m, costA, costB; int s, t, dis, root; int dist[MAXN], num[MAXN], q[MAXN]; vector <int> u, v; vector <int> a[MAXN], input; void work(int l, int r) { if (l == r) { root = l; return; } int mid = (l + r) / 2; for (int i = 0; i < m; i++) if ((u[i] >= 1 && u[i] <= mid) || (v[i] >= 1 && v[i] <= mid)) input[i] = 1; else input[i] = 0; if (ask(input) != 1ll * dis * costA) work(l, mid); else work(mid + 1, r); } void bfs(int from) { memset(dist, -1, sizeof(dist)); int l = 1, r = 1, timer = 1; q[1] = from, dist[from] = 0, num[from] = 1; while (l <= r) { int tmp = q[l++]; for (unsigned i = 0; i < a[tmp].size(); i++) if (dist[a[tmp][i]] == -1) { dist[a[tmp][i]] = dist[tmp] + 1; q[++r] = a[tmp][i]; num[a[tmp][i]] = ++timer; } } } void find_pair(int tn, vector <int> tu, vector <int> tv, int ta, int tb) { n = tn, costA = ta, costB = tb; m = tu.size(), u = tu, v = tv; input.resize(m); dis = ask(input) / costA; for (int i = 0; i < m; i++) { tu[i]++, u[i]++, tv[i]++, v[i]++; int x = tu[i], y = tv[i]; a[x].push_back(y); a[y].push_back(x); } work(1, n); bfs(root); int l = 2, r = n; while (l < r) { int mid = (l + r + 1) / 2; for (int i = 0; i < m; i++) if (num[u[i]] >= mid || num[v[i]] >= mid) input[i] = 1; else input[i] = 0; if (ask(input) != 1ll * dis * costA) l = mid; else r = mid - 1; } bfs(s = q[l]); static int p[MAXN], pos[MAXN]; int tot = 0; for (int i = 1; i <= n; i++) if (dist[i] == dis) p[i] = ++tot, pos[tot] = i; else p[i] = 0; l = 1, r = tot; while (l < r) { int mid = (l + r) / 2; for (int i = 0; i < m; i++) if ((p[u[i]] >= l && p[u[i]] <= mid) || (p[v[i]] >= l && p[v[i]] <= mid)) input[i] = 1; else input[i] = 0; if (ask(input) != 1ll * dis * costA) r = mid; else l = mid + 1; } t = pos[l]; answer(s - 1, t - 1); } // 100 points Version - LogS + LogT + LogM + 1 (S + T == N) #include "highway.h" #include<bits/stdc++.h> using namespace std; const int MAXN = 1e5 + 5; typedef long long ll; int n, m, costA, costB; int s, t, dis, rootx, rooty; int distx[MAXN], numx[MAXN], qx[MAXN]; int disty[MAXN], numy[MAXN], qy[MAXN]; vector <int> u, v; vector <int> a[MAXN], input; void work(int l, int r) { if (l == r) { rootx = u[l]; rooty = v[l]; return; } int mid = (l + r) / 2; for (int i = 0; i < m; i++) if (i <= mid) input[i] = 1; else input[i] = 0; if (ask(input) != 1ll * dis * costA) work(l, mid); else work(mid + 1, r); } void bfs(int from, int *dist, int *q, int *num) { for (int i = 1; i <= n; i++) dist[i] = -1; int l = 1, r = 1, timer = 1; q[1] = from, dist[from] = 0, num[from] = 1; while (l <= r) { int tmp = q[l++]; for (unsigned i = 0; i < a[tmp].size(); i++) if (dist[a[tmp][i]] == -1) { dist[a[tmp][i]] = dist[tmp] + 1; q[++r] = a[tmp][i]; num[a[tmp][i]] = ++timer; } } } void find_pair(int tn, vector <int> tu, vector <int> tv, int ta, int tb) { n = tn, costA = ta, costB = tb; m = tu.size(), u = tu, v = tv; input.resize(m); dis = ask(input) / costA; for (int i = 0; i < m; i++) { tu[i]++, u[i]++, tv[i]++, v[i]++; int x = tu[i], y = tv[i]; a[x].push_back(y); a[y].push_back(x); } work(0, m - 1); bfs(rootx, distx, qx, numx); bfs(rooty, disty, qy, numy); static int p[MAXN], pos[MAXN]; int tot = 0; for (int i = 1; i <= n; i++) if (distx[qx[i]] < disty[qx[i]]) p[qx[i]] = ++tot, pos[tot] = qx[i]; else p[qx[i]] = 0; int l = 1, r = tot; while (l < r) { int mid = (l + r + 1) / 2; for (int i = 0; i < m; i++) if (p[u[i]] >= mid || p[v[i]] >= mid) input[i] = 1; else input[i] = 0; if (ask(input) != 1ll * dis * costA) l = mid; else r = mid - 1; } s = pos[l]; tot = 0; for (int i = 1; i <= n; i++) if (disty[qy[i]] < distx[qy[i]]) p[qy[i]] = ++tot, pos[tot] = qy[i]; else p[qy[i]] = 0; l = 1, r = tot; while (l < r) { int mid = (l + r + 1) / 2; for (int i = 0; i < m; i++) if (p[u[i]] >= mid || p[v[i]] >= mid) input[i] = 1; else input[i] = 0; if (ask(input) != 1ll * dis * costA) l = mid; else r = mid - 1; } t = pos[l]; answer(s - 1, t - 1); }