https://ac.nowcoder.com/acm/contest/15/C
很显然,虚树建出来后,dp一下,答案就是(虚树的直径+1)/ 2。。
但是wa了两个小时后才发现自己一直用的是假的虚树板子。。。
看了别人代码后,发现树的直径也可以通过离深度最大点的距离求出,并不需要建虚树dp。。。
。。。
虚树dp求直径
#include<bits/stdc++.h>
using namespace std;
vector<int>g[300010],gg[300010];
int fa[300010][25],h[300010],in[300010],out[300010],dfn;
void dfs(int u,int f){
int i,v;
fa[u][0]=f;
h[u]=h[f]+1;
in[u]=++dfn;
for(i=1;i<=20;i++) fa[u][i]=fa[fa[u][i-1]][i-1];
for(i=0;i<g[u].size();i++){
v=g[u][i];
if(v!=f) dfs(v,u);
}
out[u]=dfn;
}
int lca(int a,int b){
int i;
if(h[a]<h[b]) swap(a,b);
for(i=20;i>=0;i--) if(h[fa[a][i]]>=h[b]) a=fa[a][i];
if(a==b) return a;
for(i=20;i>=0;i--) if(fa[a][i]!=fa[b][i]) a=fa[a][i],b=fa[b][i];
return fa[a][0];
}
int A[300010],vis[300010],st[300010],m,ans;
int ma1[300010],ma2[300010];
int cmp(int a,int b){
return in[a]<in[b];
}
void clear(){
int i;
for(i=1;i<=m;i++){
vis[A[i]]=0;
gg[A[i]].clear();
}
ans=0;
}
void build(){
int i,top=0,c=m;
sort(A+1,A+1+m,cmp);
for(i=1;i<c;i++) A[++m]=lca(A[i],A[i+1]);
sort(A+1,A+1+m,cmp);
m=unique(A+1,A+1+m)-A-1;
for(i=1;i<=m;i++){
if(top==0) st[++top]=A[i];
else{
while(out[st[top]]<in[A[i]]) top--;
gg[st[top]].push_back(A[i]);
st[++top]=A[i];
}
}
}
void dfs1(int u,int f){
int i,v,d;
ma1[u]=ma2[u]=-1;
if(vis[u]) ma1[u]=0;
for(i=0;i<gg[u].size();i++){
v=gg[u][i];
dfs1(v,u);
d=ma1[v]+h[v]-h[u];
if(ma1[u]<d){
ma2[u]=ma1[u];
ma1[u]=d;
}
else if(ma2[u]<d) ma2[u]=d;
}
}
void dfs2(int u,int f,int ma){
int d=h[u]-h[f],i,v,mx=-1;
if(ma==-1){
if(ma2[u]==-1) mx=ma1[u];
else mx=ma1[u]+ma2[u];
}
else mx=ma+d+ma1[u];
ans=max(mx,ans);
for(i=0;i<gg[u].size();i++){
v=gg[u][i];
d=ma1[v]+h[v]-h[u];
if(d!=ma1[u]) mx=ma1[u];
else mx=ma2[u];
if(ma!=-1) mx=max(ma+h[u]-h[f],mx);
dfs2(v,u,mx);
}
}
void solve(){
scanf("%d",&m);
for(int i=1;i<=m;i++){
scanf("%d",&A[i]);
vis[A[i]]=1;
}
build();
dfs1(A[1],0);
dfs2(A[1],0,-1);
printf("%d\n",(ans+1)/2);
clear();
}
int main(){
int i,n,q,a,b;
scanf("%d",&n);
for(i=1;i<n;i++){
scanf("%d%d",&a,&b);
g[a].push_back(b);
g[b].push_back(a);
}
dfs(1,0);
scanf("%d",&q);
while(q--) solve();
return 0;
}
非dp求直径代码
#include<bits/stdc++.h>
using namespace std;
vector<int>g[300010];
int fa[300010][25],h[300010],A[300010],dfn;
void dfs(int u,int f){
int i,v;
fa[u][0]=f;
h[u]=h[f]+1;
for(i=1;i<=20;i++) fa[u][i]=fa[fa[u][i-1]][i-1];
for(i=0;i<g[u].size();i++){
v=g[u][i];
if(v!=f) dfs(v,u);
}
}
int lca(int a,int b){
int i;
if(h[a]<h[b]) swap(a,b);
for(i=20;i>=0;i--) if(h[fa[a][i]]>=h[b]) a=fa[a][i];
if(a==b) return a;
for(i=20;i>=0;i--) if(fa[a][i]!=fa[b][i]) a=fa[a][i],b=fa[b][i];
return fa[a][0];
}
int dis(int a,int b){
return h[a]+h[b]-h[lca(a,b)]*2;
}
int main(){
int i,n,m,q,a,b;
scanf("%d",&n);
for(i=1;i<n;i++){
scanf("%d%d",&a,&b);
g[a].push_back(b);
g[b].push_back(a);
}
dfs(1,0);
scanf("%d",&q);
while(q--){
scanf("%d",&m);
a=-1;
for(i=1;i<=m;i++){
scanf("%d",&A[i]);
if(h[A[i]]>a){
a=h[A[i]];
b=A[i];
}
}
a=-1;
for(i=1;i<=m;i++) a=max(a,dis(A[i],b));
printf("%d\n",(a+1)/2);
}
return 0;
}