题目
题意:给出一个树,给出k个关键点,这k个点每一对点的路径和,最小值以及最大值。
做法:首先我们看见多组询问并且k和n同阶,应该就知道是虚树的处理了。
首先考虑怎么树上dp。
设
d
p
[
u
]
dp[u]
dp[u]表示当前
u
u
u为根的子树中关键点之间的和,接下来按照每一个子树的遍历顺序可以得到总的代价可以通过如下关系得到:
a
n
s
+
=
(
d
p
[
u
]
+
s
z
[
u
]
∗
d
i
s
)
∗
s
z
[
v
]
+
d
p
[
v
]
∗
s
z
[
u
]
;
ans += (dp[u] + sz[u] * dis) * sz[v] + dp[v] * sz[u];
ans+=(dp[u]+sz[u]∗dis)∗sz[v]+dp[v]∗sz[u];其中
s
z
sz
sz表示关键点的个数,
d
i
s
dis
dis表示深度差,就是距离。
关于最大和最小值。
同样
m
n
[
u
]
,
m
x
[
u
]
mn[u],mx[u]
mn[u],mx[u]表示子树中的最小最大值。
类似树的直接。然后套上虚树的板子就可以了。
#include "bits/stdc++.h"
using namespace std;
inline int read() {
int x = 0;
bool f = 1;
char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') f = 0;
for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
if (f) return x;
return 0 - x;
}
#define SZ(x) ((int)x.size())
#define ll long long
const int maxn = 1000000 + 10;
const ll inf = 1e18;
struct edge {
int u, v, nxt;
} ed[maxn << 1];
int head[maxn << 1], cnt;
void add_e(int u, int v) {
ed[++cnt] = edge{u, v, head[u]};
head[u] = cnt;
}
int dep[maxn], fa[maxn][22], lg[maxn], dfn[maxn], id;
void dfs(int u, int f) {
fa[u][0] = f, dfn[u] = ++id;
dep[u] = dep[f] + 1;
for (int i = 1; i <= lg[dep[u]]; i++)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
for (int i = head[u]; i; i = ed[i].nxt)
if (ed[i].v != f)
dfs(ed[i].v, u);
}
int LCA(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
while (dep[x] > dep[y])
x = fa[x][lg[dep[x] - dep[y]] - 1];
if (x == y) return x;
for (int k = lg[dep[x]] - 1; k >= 0; k--)
if (fa[x][k] != fa[y][k])
x = fa[x][k], y = fa[y][k];
return fa[x][0];
}
vector<int> g[maxn];
int sta[maxn], top = 0;
bool cmp(int x, int y) { return dfn[x] < dfn[y]; }
int n, m, k, a[maxn], vis[maxn];
void modify(int x) {
if (top == 1) {
sta[++top] = x;
return;
}
int lca = LCA(x, sta[top]);
if (lca == sta[top]) {
sta[++top] = x;
return;
}
while (top > 1 && dfn[sta[top - 1]] >= dfn[lca])
g[sta[top - 1]].push_back(sta[top]), top--;
if (lca != sta[top]) g[lca].push_back(sta[top]), sta[top] = lca;
sta[++top] = x;
}
ll dp[maxn], mn[maxn], mx[maxn], sz[maxn];
ll ans, vmax, vmin;
void solve(int u) {
dp[u] = 0;
sz[u] = vis[u];
if (vis[u]) mn[u] = mx[u] = 0;
else mn[u] = inf, mx[u] = -inf;
for (int v:g[u]) {
solve(v);
ll dis = dep[v] - dep[u];
ans += (dp[u] + sz[u] * dis) * sz[v] + dp[v] * sz[u];
sz[u] += sz[v];
dp[u] += dp[v] + sz[v] * dis;
vmax = max(vmax, mx[u] + mx[v] + dis);
vmin = min(vmin, mn[u] + mn[v] + dis);
mx[u] = max(mx[u], mx[v] + dis);
mn[u] = min(mn[u], mn[v] + dis);
}
g[u].clear();
}
int main() {
n = read();
for (int i = 1; i <= n; i++)
lg[i] = lg[i - 1] + (1 << lg[i - 1] == i);
for (int i = 1, u, v; i < n; i++) {
u = read(), v = read();
add_e(u, v);
add_e(v, u);
}
dfs(1, 0);
m = read();
while (m--) {
k = read();
for (int i = 1; i <= k; i++) {
a[i] = read();
vis[a[i]] = 1;
}
if (k == 1) {
printf("0 0 0\n");
} else {
sort(a + 1, a + k + 1, cmp);
sta[top = 1] = 1;
for (int i = 1; i <= k; i++) {
if (a[i] != 1) modify(a[i]);
}
while (top > 1) g[sta[top - 1]].push_back(sta[top]), top--;
ans = 0;
vmax = -inf;
vmin = inf;
solve(1);
printf("%lld %lld %lld\n", ans, vmin, vmax);
}
for (int i = 1; i <= k; i++) vis[a[i]] = 0;
}
return 0;
}