【题目链接】
【思路要点】
- 考虑 $O(NM)$ 的暴力：对于每个询问，我们需要进行一次树形 $dp$。
- $dp$ 的状态大致是：令某个点 $i$ 取/不取时，其子树内的最优权值和。
- 考虑优化。对于一个询问 $(x,y)$，将路径 $(x,y)$ 单独考虑：对于路径 $(x,y)$ 上的每个点，其不在路径上的子树内的 $dp$ 值是可以预处理的。并且，若我们将路径 $(x,y)$ 拆分为 $(x,Lca)$、$(y,Lca)$ 两段，对每段路径记录路径首/尾处取/不取时的最优解，就可以通过倍增预处理某一段的 $dp$ 值合并后的结果。
- 具体细节可见代码。
- 时间复杂度 $O(N\log N+M\log N)$。
【代码】
// Minimum-weight vertex cover of a tree ("defense"): every query forces the
// states of two cities, and is answered in O(log N) by concatenating
// precomputed 2x2 (min,+) path summaries with binary lifting.
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

namespace {

using ll = long long;

constexpr int MAXN = 100005;
// Every binary-lifting jump distance is < MAXN, and MAXN <= 2^MAXLOG, so
// levels 0..MAXLOG-1 suffice (22 levels only wasted memory in the info table).
constexpr int MAXLOG = 17;
static_assert((1 << MAXLOG) >= MAXN, "MAXLOG too small for MAXN");

// Marks an infeasible state. A combination adds at most two values that are
// <= INF (so <= 2e18, no overflow) and chkmin clamps the result back to INF.
constexpr ll INF = 1000000000000000000LL;

// Reads an integer from stdin, skipping non-digit characters before it (each
// skipped '-' flips the sign). Leaves x == 0 if input ends first instead of
// spinning forever on EOF.
template <typename T>
void read(T &x) {
    x = 0;
    int sign = 1;
    int ch = std::getchar();  // int, not char: keeps EOF distinct and isdigit defined
    for (; ch != EOF && !std::isdigit(ch); ch = std::getchar())
        if (ch == '-') sign = -sign;
    for (; std::isdigit(ch); ch = std::getchar()) x = x * 10 + (ch - '0');
    x *= sign;
}

template <typename T>
void chkmin(T &x, T y) { x = std::min(x, y); }

// Summary of a vertical path segment: a[s][t] is the minimum cost of the
// segment's vertices (each together with its off-path subtrees) when the
// segment's first vertex is in state s and its last vertex in state t
// (1 = garrisoned, 0 = not). INF marks an infeasible combination.
struct info {
    ll a[2][2];
};

// A summary in which every endpoint combination is infeasible.
info allInf() {
    info ans;
    for (auto &row : ans.a)
        for (ll &cell : row) cell = INF;
    return ans;
}

// Concatenates segment a followed by segment b. The last vertex of a and the
// first vertex of b are adjacent, so they may not both be unguarded.
info operator+(const info &a, const info &b) {
    info ans = allInf();
    for (int i = 0; i <= 1; i++)
        for (int j = 0; j <= 1; j++)
            for (int p = 0; p <= 1; p++)
                for (int q = 0; q <= 1; q++)
                    if (p | q) chkmin(ans.a[i][j], a.a[i][p] + b.a[q][j]);
    return ans;
}

int depth[MAXN];             // depth[0] == 0 is a sentinel above the root
int father[MAXN][MAXLOG];    // father[x][i]: 2^i-th ancestor of x (0 if none)
ll val[MAXN];                // cost of garrisoning each city
ll dp[MAXN][2];              // dp[x][s]: best cost inside subtree(x), x in state s
ll fp[MAXN][2];              // fp[x][s]: dp[father(x)][s] with subtree(x) removed
ll up[MAXN][2];              // up[x][s]: best cost outside subtree(x), father(x) in state s
info v[MAXN][MAXLOG];        // v[x][i]: summary of the 2^i ancestors of x, nearest first,
                             //          each without the child lying on the path
std::vector<int> adj[MAXN];  // adjacency lists
std::vector<int> order;      // vertices in BFS order: parents precede children

// Contribution of child c to its parent's dp when the parent is in state s:
// an unguarded parent forces c to be guarded.
ll childShare(int c, int s) {
    return s == 0 ? dp[c][1] : std::min(dp[c][0], dp[c][1]);
}

// Best cost over the whole tree with vertex pos forced into state s.
ll wholeTree(int pos, int s) {
    return s == 0 ? dp[pos][0] + up[pos][1]
                  : dp[pos][1] + std::min(up[pos][0], up[pos][1]);
}

// Roots the tree at `root` and fills father/depth/order. Iterative, so a
// path-shaped tree with 1e5 vertices cannot overflow the call stack.
void buildTree(int root) {
    order.assign(1, root);
    father[root][0] = 0;
    depth[root] = depth[0] + 1;
    for (std::size_t head = 0; head < order.size(); head++) {
        const int pos = order[head];
        for (int nxt : adj[pos]) {
            if (nxt == father[pos][0]) continue;
            father[nxt][0] = pos;
            depth[nxt] = depth[pos] + 1;
            order.push_back(nxt);
        }
    }
    // BFS order guarantees every ancestor's table is filled before it is read.
    for (int pos : order)
        for (int i = 1; i < MAXLOG; i++)
            father[pos][i] = father[father[pos][i - 1]][i - 1];
}

// Bottom-up pass: subtree dp, plus each child's "parent minus me" value fp.
void computeDown() {
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int pos = *it;
        dp[pos][0] = 0;
        dp[pos][1] = val[pos];
        for (int c : adj[pos]) {
            if (c == father[pos][0]) continue;
            dp[pos][0] += childShare(c, 0);
            dp[pos][1] += childShare(c, 1);
        }
        for (int c : adj[pos]) {
            if (c == father[pos][0]) continue;
            fp[c][0] = dp[pos][0] - childShare(c, 0);
            fp[c][1] = dp[pos][1] - childShare(c, 1);
        }
    }
}

// Top-down pass (after computeDown): cost of everything outside each subtree.
// up[root] stays 0, since nothing lies outside the root's subtree.
void computeUp() {
    for (int pos : order) {
        const ll whole[2] = {wholeTree(pos, 0), wholeTree(pos, 1)};
        for (int c : adj[pos]) {
            if (c == father[pos][0]) continue;
            up[c][0] = whole[0] - childShare(c, 0);
            up[c][1] = whole[1] - childShare(c, 1);
        }
    }
}

int lca(int x, int y) {
    if (depth[x] < depth[y]) std::swap(x, y);
    for (int i = MAXLOG - 1; i >= 0; i--)
        if (depth[father[x][i]] >= depth[y]) x = father[x][i];
    if (x == y) return x;
    for (int i = MAXLOG - 1; i >= 0; i--)
        if (father[x][i] != father[y][i]) {
            x = father[x][i];
            y = father[y][i];
        }
    return father[x][0];
}

// Lifts x (a proper descendant of f) to the child of f on the path to x.
int childToward(int x, int f) {
    for (int i = MAXLOG - 1; i >= 0; i--)
        if (depth[father[x][i]] > depth[f]) x = father[x][i];
    return x;
}

// One-vertex segment: pos with its whole subtree, forced into state tx.
info formdp(int pos, int tx) {
    info ans = allInf();
    ans.a[tx][tx] = dp[pos][tx];
    return ans;
}

// Path from x (state tx, full subtree) up to but excluding its proper
// ancestor f. Start state = x's state; end state = state of f's child.
info getpath(int x, int f, int tx) {
    info ans = formdp(x, tx);
    for (int i = MAXLOG - 1; i >= 0; i--)
        if (depth[father[x][i]] > depth[f]) {
            ans = ans + v[x][i];
            x = father[x][i];
        }
    return ans;
}

// Same segment read in the opposite direction (swap start and end states).
info Reverse(info a) {
    std::swap(a.a[0][1], a.a[1][0]);
    return a;
}

// Vertex f = lca(x, y), x and y strictly below it: everything in the tree
// except the two subtrees that hold the x- and y-paths.
info getlca(int x, int y, int f) {
    const int cx = childToward(x, f);
    const int cy = childToward(y, f);
    info ans = allInf();
    for (int s = 0; s <= 1; s++)
        ans.a[s][s] = wholeTree(f, s) - childShare(cx, s) - childShare(cy, s);
    return ans;
}

// Vertex x (an ancestor of y) forced into state tx: everything in the tree
// except the subtree holding the path down to y.
info getlcb(int x, int y, int tx) {
    info ans = allInf();
    ans.a[tx][tx] = wholeTree(x, tx) - childShare(childToward(y, x), tx);
    return ans;
}

// Best over all endpoint states, or -1 if every combination is infeasible.
ll getans(const info &a) {
    ll ans = INF;
    for (const auto &row : a.a)
        for (ll cell : row) chkmin(ans, cell);
    return ans >= INF ? -1 : ans;
}

// Answers one query: city x in state tx and city y in state ty.
ll answer(int x, int tx, int y, int ty) {
    if (x == y)  // both demands hit one city: consistent only if they agree
        return tx == ty ? wholeTree(x, tx) : -1;
    if (depth[x] > depth[y]) {
        std::swap(x, y);
        std::swap(tx, ty);
    }
    const int z = lca(x, y);
    if (x == z)
        return getans(getlcb(x, y, tx) + Reverse(getpath(y, x, ty)));
    return getans(getpath(x, z, tx) + getlca(x, y, z) + Reverse(getpath(y, z, ty)));
}

}  // namespace

int main() {
    std::freopen("defense.in", "r", stdin);
    std::freopen("defense.out", "w", stdout);
    int n = 0, m = 0;
    read(n), read(m);
    char type[16];  // data-category tag (e.g. "A1"); not needed by this solution
    if (std::scanf("%15s", type) != 1) return 0;
    for (int i = 1; i <= n; i++) read(val[i]);
    for (int i = 1; i <= n - 1; i++) {
        int x, y;
        read(x), read(y);
        adj[x].push_back(y);
        adj[y].push_back(x);
    }
    buildTree(1);
    computeDown();
    computeUp();
    // Level 0: the single vertex father(i), viewed without subtree(i).
    // Sentinel vertex 0 gets an all-INF row so out-of-tree jumps stay infeasible.
    v[0][0] = allInf();
    for (int i = 1; i <= n; i++) {
        v[i][0] = allInf();
        v[i][0].a[0][0] = fp[i][0];
        v[i][0].a[1][1] = fp[i][1];
    }
    for (int p = 1; p < MAXLOG; p++)
        for (int i = 0; i <= n; i++)
            v[i][p] = v[i][p - 1] + v[father[i][p - 1]][p - 1];
    for (int q = 0; q < m; q++) {
        int x, tx, y, ty;
        read(x), read(tx), read(y), read(ty);
        std::printf("%lld\n", answer(x, tx, y, ty));
    }
    return 0;
}