【题目链接】
【前置技能】
- 虚树
- 树形DP
【题解】
- 所有询问的关键点数量之和与 $N$ 同阶,那么对每次询问建虚树,剩下来的问题就在于树形DP了。
- 树形DP记录四个值:一个端点为 $pos$、另一个端点为以 $pos$ 为根的子树中的关键点的路径的最大长度 $maxn$、最小长度 $minn$、数量 $tot$、长度和 $sum$。转移要分 $pos$ 是关键点和 $pos$ 非关键点两种情况。具体写法参见代码。
- 时间复杂度 $O(N\log N+\sum K\log N)$
【代码】
#include<bits/stdc++.h>
// Sentinel "infinity" for int minima; getans() adds two of them, and
// 2 * 0x3f3f3f3f still fits in a 32-bit int.
#define INF 0x3f3f3f3f
#define LL long long
// Capacity of every per-vertex array below.
#define MAXN 1000010
// Highest binary-lifting level scanned by lca(); fa[][] has MAXLOG + 2 columns.
#define MAXLOG 22
using namespace std;
// n: vertex count read in main(); k: number of key vertices in the current query.
// q[1..k]: key vertices of the current query (sorted by dfn inside build()).
// used[1..cnt]: vertices of the current virtual tree, recorded by build() so
//               main() can clear their child lists afterwards.
// mark[v]: set to 1 by main() while v is a key vertex of the current query.
// ansmin / ansmax: per-query minimum / maximum pairwise distance, updated by getans().
int n, k, q[MAXN], used[MAXN], cnt, mark[MAXN], ansmin, ansmax;
// dep[v]: depth assigned by dfs() (dep[0] stays 0, the start vertex gets 1).
// fa[v][i]: 2^i-th ancestor of v as filled by dfs(); entries never written stay 0.
// dfn[v]: preorder index assigned by dfs(); tim: running preorder counter.
// root: root of the current virtual tree, set by build().
int dep[MAXN], fa[MAXN][MAXLOG + 2], dfn[MAXN], root, tim;
// Per-vertex DP values written by getans() for the current virtual tree:
// tot[v]  - number of key vertices in v's virtual subtree,
// maxn[v] - largest distance from v to a key vertex in that subtree,
// minn[v] - smallest such distance (0 when v itself is a key vertex).
int tot[MAXN], maxn[MAXN], minn[MAXN];
// sum[v]: sum of distances from v to the key vertices in its virtual subtree.
// ans: per-query sum of pairwise distances, accumulated by getans().
LL sum[MAXN], ans;
// a[v]: adjacency list of the input tree; b[v]: child list of v in the virtual tree.
vector <int> a[MAXN], b[MAXN];
// Lowers x to y when y is strictly smaller; otherwise leaves x untouched.
template <typename T> void chkmin(T &x, T y){
    if (y < x) x = y;
}
// Raises x to y when y is strictly larger; otherwise leaves x untouched.
template <typename T> void chkmax(T &x, T y){
    if (x < y) x = y;
}
// Reads one optionally negative decimal integer from stdin into x.
//
// Every non-digit character before the number is skipped; if a '-' occurs
// among the skipped characters the result is negated. When the input ends
// before any digit is found, x is left at 0 and the function returns instead
// of spinning forever on EOF.
template <typename T> void read(T &x){
    x = 0;
    bool negative = false;
    // Keep the character as int: getchar() returns EOF (a negative int) at end
    // of input, and isdigit() is only defined for unsigned-char values or EOF.
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)){
        if (ch == '-') negative = true;
        ch = getchar();
    }
    while (isdigit(ch)){
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    if (negative) x = -x;
}
// Depth-first traversal of the input tree starting at `pos`, whose parent is
// `dad` (0 for the overall root). For every vertex reached it assigns
//   dfn[v]   - preorder index (taken from the running counter `tim`),
//   dep[v]   - dep[parent] + 1,
//   fa[v][i] - 2^i-th ancestor, for every i with 2^i <= dep[v].
//
// The traversal uses an explicit stack instead of recursion: the tree may be
// a path of up to MAXN vertices, and one call frame per level can exhaust
// the call stack. Neighbours are visited in adjacency-list order, so the
// numbering matches a recursive preorder walk.
void dfs(int pos, int dad){
    // Each entry is (vertex, index of the next neighbour of that vertex to examine).
    std::vector<std::pair<int, unsigned> > pending;
    fa[pos][0] = dad;
    pending.push_back(std::make_pair(pos, 0u));
    while (!pending.empty()){
        const int cur = pending.back().first;
        const unsigned next = pending.back().second;
        const int parent = fa[cur][0];
        if (next == 0){
            // First time cur is on top of the stack: this is its preorder visit.
            dfn[cur] = ++tim;
            dep[cur] = dep[parent] + 1;
            for (int i = 1; (1 << i) <= dep[cur]; ++i)
                fa[cur][i] = fa[fa[cur][i - 1]][i - 1];
        }
        if (next >= a[cur].size()){
            // All neighbours handled; cur's subtree is finished.
            pending.pop_back();
            continue;
        }
        // Advance the cursor before pushing: push_back may reallocate the stack.
        pending.back().second = next + 1;
        const int son = a[cur][next];
        if (son != parent){
            fa[son][0] = cur;
            pending.push_back(std::make_pair(son, 0u));
        }
    }
}
// Returns the lowest common ancestor of u and v using the binary-lifting
// table fa[][] and the depths dep[] prepared by dfs().
int lca(int u, int v){
    // Name the two endpoints by depth instead of swapping the arguments.
    int shallow = (dep[u] > dep[v]) ? v : u;
    int deep = (dep[u] > dep[v]) ? u : v;
    // Lift the deeper vertex until both are on the same level. Table entries
    // that were never filled are 0 and dep[0] is 0, so they are never taken.
    for (int level = MAXLOG; level >= 0; --level){
        const int up = fa[deep][level];
        if (dep[up] >= dep[shallow]) deep = up;
    }
    if (deep == shallow) return shallow;
    // Lift both while their ancestors differ; they stop just below the LCA.
    for (int level = MAXLOG; level >= 0; --level){
        if (fa[shallow][level] == fa[deep][level]) continue;
        shallow = fa[shallow][level];
        deep = fa[deep][level];
    }
    return fa[shallow][0];
}
// Records x as a child of dad in the virtual tree (appends to b[dad]).
void add(int dad, int x){
    std::vector<int> &children = b[dad];
    children.push_back(x);
}
// Ordering predicate for sort(): vertex a comes before vertex b when its
// preorder index dfn[] is smaller.
bool cmp(int a, int b){
    const int first = dfn[a];
    const int second = dfn[b];
    return second > first;
}
// Builds the virtual tree of the key vertices q[1..k].
//
// The key vertices are sorted by preorder index; the root is the LCA of the
// first and last of them. A stack holds the current root-to-vertex chain.
// Whenever a vertex leaves the stack, the edge from its virtual parent is
// stored with add() and the vertex is appended to used[] so the caller can
// clear its child list later. On return `root` is the virtual root and
// used[1..cnt] lists every vertex of the virtual tree.
void build(){
    static int sta[MAXN], top;
    sort(q + 1, q + k + 1, cmp);
    root = lca(q[1], q[k]);
    cnt = 1;
    used[cnt] = root;
    top = 1;
    sta[top] = root;
    // If the first key vertex is the root itself it is already on the stack.
    const int first = (q[1] == root) ? 2 : 1;
    for (int i = first; i <= k; ++i){
        const int cur = q[i];
        const int low = lca(cur, sta[top]);
        if (low != sta[top]){
            // Close every stack vertex whose parent on the stack lies strictly
            // below `low` (i.e. has a larger preorder index than `low`).
            for (; dfn[low] < dfn[sta[top - 1]]; --top){
                add(sta[top - 1], sta[top]);
                used[++cnt] = sta[top];
            }
            if (dfn[low] == dfn[sta[top - 1]]){
                // `low` is already the vertex right below the top: attach and pop.
                add(sta[top - 1], sta[top]);
                used[++cnt] = sta[top];
                --top;
            } else {
                // `low` is a new branching vertex: it takes the popped vertex's slot.
                add(low, sta[top]);
                used[++cnt] = sta[top];
                sta[top] = low;
            }
        }
        sta[++top] = cur;
    }
    // Flush the remaining chain; the root at sta[1] stays on the stack.
    for (; top >= 2; --top){
        add(sta[top - 1], sta[top]);
        used[++cnt] = sta[top];
    }
}
void getans(int pos){
tot[pos] = sum[pos] = 0;
if (mark[pos]){
minn[pos] = maxn[pos] = 0;
int MAXI = 0, MAXII = 0;
for (unsigned i = 0, si = b[pos].size(); i < si; ++i){
int son = b[pos][i], wth = dep[son] - dep[pos];
getans(son);
ans += 1ll * sum[pos] * tot[son] + (sum[son] + 1ll * tot[son] * wth) * 1ll * tot[pos];
sum[pos] += sum[son] + 1ll * tot[son] * wth;
tot[pos] += tot[son];
int tmp = minn[son] + wth;
chkmin(ansmin, tmp);
tmp = maxn[son] + wth;
chkmax(maxn[pos], tmp);
if (tmp > MAXI) MAXII = MAXI, MAXI = tmp;
else chkmax(MAXII, tmp);
}
++tot[pos];
ans += sum[pos];
chkmax(ansmax, MAXI + MAXII);
} else {
minn[pos] = INF, maxn[pos] = 0;
int MAXI = 0, MAXII = 0, MINI = INF, MINII = INF;
for (unsigned i = 0, si = b[pos].size(); i < si; ++i){
int son = b[pos][i], wth = dep[son] - dep[pos];
getans(son);
ans += 1ll * sum[pos] * tot[son] + (sum[son] + 1ll * tot[son] * wth) * 1ll * tot[pos];
sum[pos] += sum[son] + 1ll * tot[son] * wth;
tot[pos] += tot[son];
int tmp = maxn[son] + wth;
chkmax(maxn[pos], tmp);
if (tmp > MAXI) MAXII = MAXI, MAXI = tmp;
else chkmax(MAXII, tmp);
tmp = minn[son] + wth;
chkmin(minn[pos], tmp);
if (tmp < MINI) MINII = MINI, MINI = tmp;
else chkmin(MINII, tmp);
}
chkmin(ansmin, MINI + MINII);
chkmax(ansmax, MAXI + MAXII);
}
}
// Reads the tree and the queries from stdin. For each query it prints the sum,
// minimum and maximum of the pairwise distances between the key vertices.
int main(){
    read(n);
    for (int edge = 1; edge < n; ++edge){
        int u, v;
        read(u);
        read(v);
        a[u].push_back(v);
        a[v].push_back(u);
    }
    // Vertex 1 is the root; its parent is the dummy vertex 0.
    dfs(1, 0);
    int Q;
    read(Q);
    while (Q-- != 0){
        read(k);
        for (int i = 1; i <= k; ++i){
            read(q[i]);
            mark[q[i]] = 1;
        }
        build();
        ans = 0;
        ansmax = 0;
        ansmin = INF;
        getans(root);
        printf("%lld %d %d\n", ans, ansmin, ansmax);
        // Undo only what this query touched so the next one starts clean.
        for (int i = k; i >= 1; --i)
            mark[q[i]] = 0;
        for (int i = cnt; i >= 1; --i)
            b[used[i]].clear();
    }
    return 0;
}