题目链接
题意
给你一颗树边权为1多次询问,每次询问若干个点,求任意两点之间距离之和,最远点对权值,最近点对权值
思路
询问很多,但是询问总共的点集不超过
2000000
2000000
2000000 可以对每次查询建虚树,得到一颗新树后树形dp。
d
p
[
u
]
[
0
]
dp[u][0]
dp[u][0] 表示 子树中的询问节点数量
d
p
[
u
]
[
1
]
dp[u][1]
dp[u][1] 表示 以u为根节点子树中询问节点到u的最近距离
d
p
[
u
]
[
1
]
dp[u][1]
dp[u][1] 表示 以u为根节点子树中询问节点到u的最远距离
代码
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <vector>
using namespace std;
#define ll long long
namespace shupou {
const ll N = 1000005;
const ll M = 1000005<<1;
ll first[N], tot;
struct Node {
ll v, w, nxt;
}e[M];
ll deep[N], f[N], sz[N], son[N];
ll cnt, dfn[N], top[N], w[N];
void init() {
memset(first,-1,sizeof(first));
w[1] = 1e18;
tot = 0;
cnt = 0;
deep[0] = 0;
}
void add(ll u, ll v, ll w) {
e[tot].v = v;
e[tot].w = w;
e[tot].nxt = first[u];
first[u] = tot++;
}
void dfs1(ll u, ll fa) {
deep[u] = deep[fa]+1;
f[u] = fa;
sz[u] = 1;
son[u] = 0;
ll maxson = -1;
for(ll i = first[u]; ~i; i = e[i].nxt) {
ll v = e[i].v;
if(v == fa) continue;
w[v] = w[u]+1;
dfs1(v,u);
sz[u] += sz[v];
if(sz[v] > maxson) son[u] = v, maxson = sz[v];
}
}
void dfs2(ll u, ll topfa) {
dfn[u] = ++cnt;
top[u] = topfa;
if(!son[u]) return;
dfs2(son[u],topfa);
for(ll i = first[u]; ~i; i = e[i].nxt) {
ll v = e[i].v;
if(v == son[u] || v == f[u]) continue;
dfs2(v,v);
}
}
ll getlca(ll x, ll y) {
while(top[x] != top[y]) {
if(deep[top[x]] < deep[top[y]]) swap(x,y);
x = f[top[x]];
}
return deep[x] > deep[y] ? y : x;
}
}
const ll N = 1000005;
bool cmp(ll a, ll b) {
return shupou::dfn[a] < shupou::dfn[b];
}
ll vis[N], sta[N], top;
vector<ll> e[N];
void push(ll x) {
int lc = shupou::getlca(x,sta[top]);
if(lc == sta[top]) {
sta[++top] = x;
return;
}
while(shupou::deep[lc] <= shupou::deep[sta[top-1]]) e[sta[top-1]].push_back(sta[top]), --top;
if(shupou::deep[lc] != shupou::deep[sta[top]]) e[lc].push_back(sta[top]), sta[top] = lc;
sta[++top] = x;
}
ll dp[N][4], sum, minans, maxans; // 0 子树节点数,1子树最近到该点,2子树最远到该点
ll mp[N], m;
void dfs(ll u) {
if(mp[u]) dp[u][0] = 1, dp[u][1] = 0, dp[u][2] = 0;
else dp[u][0] = 0, dp[u][1] = 1e9, dp[u][2] = -1e9;
for(ll i = 0; i < e[u].size(); ++i) {
ll v = e[u][i];
ll w = shupou::w[v]-shupou::w[u];
dfs(v);
sum += dp[v][0]*(m-dp[v][0])*w;
minans = min(minans, dp[u][1]+dp[v][1]+w);
maxans = max(maxans, dp[u][2]+dp[v][2]+w);
dp[u][2] = max(dp[u][2], dp[v][2]+w);
dp[u][1] = min(dp[u][1], dp[v][1]+w);
dp[u][0] += dp[v][0];
}
e[u].clear();
}
int main() {
shupou::init();
ll n;
scanf("%lld",&n);
for(ll i = 1; i < n; ++i) {
ll u, v;
scanf("%lld%lld",&u,&v);
shupou::add(u,v,1);
shupou::add(v,u,1);
}
shupou::dfs1(1,0);
shupou::dfs2(1,1);
ll t;
for(scanf("%lld",&t); t; --t) {
scanf("%lld",&m);
for(ll i = 1; i <= m; ++i) scanf("%lld",&vis[i]), mp[vis[i]] = 1;
sort(vis+1,vis+1+m,cmp);
sta[top = 1] = 1;
if(vis[1] != 1) push(vis[1]);
for(ll i = 2; i <= m; ++i) push(vis[i]);
while(top > 1) e[sta[top - 1]].push_back(sta[top]), --top;
minans = 1e9, maxans = -1e9, sum = 0;
dfs(1);
printf("%lld %lld %lld\n",sum,minans,maxans);
for(ll i = 1; i <= m; ++i) mp[vis[i]] = 0;
}
return 0;
}