在虚树上卡了四天,终于过了,对树形DP(tree DP)也印象更深了些。这里先挖个坑(基环树DP,该搞搞了)。
坑点1:两个点之间要统计的不仅仅是路径长度,而是路径上的所有点(连同挂在这些点上的子树),所以应该用子树大小(num)相减。
坑点2:用st表(倍增数组ff)向上查找分界点时可能越界,因为中点(mid)不一定在这条链上,所以要先特判某一方完全占优的情况。
坑点3:求每个点归属于哪个关键点时,被更新的点必须重新加入队列,否则它后面的点不一定会随之更新。
#include<bits/stdc++.h>
using namespace std;
// Capacity of every per-vertex array.
const int maxn = 3e5 + 7;
using ll = long long;

// One directed half of an undirected edge, stored in a linked adjacency list.
struct node {
    int to;    // vertex this half-edge points at
    int next;  // index of the next half-edge leaving the same vertex (-1 ends the list)
};

// Every undirected edge is stored twice (once per direction).
node edge[maxn + maxn];

int cnt;         // number of half-edges stored so far (edge[1..cnt] are in use)
int head[maxn];  // first half-edge of each vertex, -1 when it has none
int num[maxn];   // subtree size of each vertex, filled by dfs()
int pos;         // running DFS timestamp counter used by dfs2()
int son[maxn];   // heavy child of each vertex, -1 when it has no children
int deep[maxn];  // depth of each vertex (dfs starts the root at depth 1)
void add(int x,int y){
cnt++;
edge[cnt].to=y;
edge[cnt].next=head[x];
head[x]=cnt;
}
// Clear the edge list, the DFS counter and the heavy-child table before building a tree.
void init(){
    pos = 0;
    cnt = 0;
    std::fill(std::begin(head), std::end(head), -1);
    std::fill(std::begin(son), std::end(son), -1);
}
int ff[maxn][21];  // ff[u][k]: 2^k-th ancestor of u (0 once past the root)
int fa[maxn];      // parent of each vertex (0 for the root)

// First DFS: record parent, depth w, binary-lifting table, subtree size
// and heavy child (the child with the largest subtree) of every vertex.
void dfs(int u,int pre,int w){
    fa[u] = pre;
    deep[u] = w;
    num[u] = 1;
    ff[u][0] = pre;
    for (int k = 1; k <= 20; ++k) {
        const int mid = ff[u][k - 1];
        ff[u][k] = ff[mid][k - 1];
    }
    for (int e = head[u]; e != -1; e = edge[e].next) {
        const int child = edge[e].to;
        if (child == pre) continue;
        dfs(child, u, w + 1);
        num[u] += num[child];
        const bool heavier = son[u] == -1 || num[child] > num[son[u]];
        if (heavier) son[u] = child;
    }
}
int p[maxn];   // DFS timestamp of each vertex (heavy child visited first)
int tp[maxn];  // top vertex of the heavy chain containing each vertex
int fp[maxn];  // inverse of p: vertex with a given timestamp

// Second DFS: give every vertex a timestamp and its chain top sd.
// The heavy child continues the current chain; each light child opens its own.
void dfs2(int u,int sd){
    tp[u] = sd;
    p[u] = ++pos;
    fp[pos] = u;
    if (son[u] == -1) return;  // no children
    dfs2(son[u], sd);
    for (int e = head[u]; e != -1; e = edge[e].next) {
        const int nxt = edge[e].to;
        if (nxt != son[u] && nxt != fa[u])
            dfs2(nxt, nxt);
    }
}
// Lowest common ancestor of u and v using the heavy-light decomposition
// built by dfs/dfs2 (tp = chain top, fa = parent, deep = depth).
// While the two vertices sit on different chains, the one whose chain top is
// deeper jumps straight to the parent of that top. Once both are on the same
// chain, the shallower vertex is the LCA.
int LCA(int u,int v){
    while (tp[u] != tp[v]) {
        if (deep[tp[v]] > deep[tp[u]]) std::swap(u, v);
        // Jump past the whole chain. Stepping to fa[u] one vertex at a time
        // gives the same answer but costs O(depth) per call.
        u = fa[tp[u]];
    }
    return deep[u] > deep[v] ? v : u;
}
// ---- per-query state ----
int A[maxn];      // A[1..K]: key vertices of the query, sorted by timestamp p[] in main
int B[maxn];      // B[1..K]: the same key vertices in input order (used for output)
int s[maxn];      // s[1..top]: stack used by insert() while building the virtual tree
int top;          // current stack height
bool vis1[maxn];  // vertex is already part of the current virtual tree
bool vis2[maxn];  // vertex is a key vertex of the current query
std::vector<int> v[maxn];  // virtual-tree adjacency lists (each edge stored in both directions)
std::vector<int> vv;       // every vertex of the current virtual tree
int Size[maxn];   // Size[k]: number of original vertices owned by key vertex k

// Ownership record of a virtual-tree vertex, written by DP().
struct ttt {
    int num;  // id of the owning (nearest) key vertex
    int len;  // distance to that key vertex
};
ttt G1[maxn];
// Number of edges on the tree path between x and y.
int lenn(int x,int y){
    const int meet = LCA(x, y);
    return (deep[x] - deep[meet]) + (deep[y] - deep[meet]);
}
// Sort predicate: orders vertices by increasing DFS timestamp p[].
int cmp1(int x,int y){
    const bool earlier = p[y] > p[x];
    return earlier;
}
int root;  // shallowest vertex of the current virtual tree (maintained by main/insert)

// Ancestor x levels above u, read bit by bit from the binary-lifting table ff.
// Yields 0 once the walk goes above the root.
int find1(int u,int x){
    for (int bit = 0; bit <= 20; ++bit) {
        if ((x >> bit) & 1) u = ff[u][bit];
    }
    return u;
}
// Add key vertex x to the virtual tree under construction. main() calls this
// with key vertices in increasing p[] (DFS timestamp) order.
// s[1..top] is a stack of virtual vertices whose parent edge has not been
// emitted yet. When a vertex leaves the stack, its virtual edge is appended
// to v[] in both directions. A new branching vertex (the lca) is marked in
// vis1 and appended to vv, and root is updated if the lca is shallower.
void insert(int x){
if(!top){s[++top]=x;return ;}
int lca=LCA(x,s[top]);
// pop stack vertices whose timestamp is after lca's, linking each to the vertex below it on the stack
while(top>1&&p[s[top-1]]>p[lca]){
v[s[top-1]].push_back(s[top]);
v[s[top]].push_back(s[top-1]);
top--;
}
// s[top] is still after lca in DFS order: hang it under lca
if(p[lca]<p[s[top]]){
if(vis1[lca]==0){
// first time this lca appears: make it a virtual vertex
vis1[lca]=1;
// keep root as the shallowest virtual vertex
if(deep[lca]<deep[root])
root=lca;
vv.push_back(lca);
}
v[s[top]].push_back(lca);
v[lca].push_back(s[top]);top--;
}
// push lca unless it is already on top of the stack
if(!top||p[lca]>p[s[top]])s[++top]=lca;
s[++top]=x;
}
bool in1[maxn];//判断这个点是否已经在队列中了
void DP(){
queue<int>H;
for(int i=0;i<vv.size();i++){
if(vis2[vv[i]]){
G1[vv[i]].num=vv[i];
G1[vv[i]].len=0;
continue;
}
G1[vv[i]].len=1000006;
}
int u1,u2,len2;
for(int i=0;i<vv.size();i++){
if(vis2[vv[i]]){
H.push(vv[i]);
}
while(!H.empty()){
u1=H.front();H.pop();
for(int i=0;i<v[u1].size();i++){
u2=v[u1][i];len2=lenn(u2,u1);
if(G1[u1].len+len2<G1[u2].len){
G1[u2].num=G1[u1].num;
G1[u2].len=G1[u1].len+len2;
H.push(u2);
}else if(G1[u1].len+len2==G1[u2].len){
G1[u2].num=min(G1[u1].num,G1[u2].num);
H.push(u2);
}
}
}
}
}
int dfs3(int x,int pre){
if(v[x].size()==1&&pre!=-1){
Size[x]+=num[x]-1;
}else{
int u1,sum1,len1,len2,len3,len4,son1;
sum1=0;
for(int i=0;i<v[x].size();i++){
u1=v[x][i];
if(u1==pre)continue;
dfs3(u1,x);
len1=lenn(x,u1);
len1--;
len4=len1;
son1=find1(u1,len1);
sum1+=num[son1];
if(len1==0)continue;
if(G1[u1].num==G1[x].num){
// cout <<x <<"to" << u1 <<"same " <<Size[G1[u1].num] <<" +" <<num[son1]-num[u1]<<endl;
// cout << G1[x].num <<" "<<G1[u1].num << endl;
Size[G1[u1].num]+=num[son1]-num[u1];
}else{
len2=G1[x].len;len3=G1[u1].len;
// cout <<" diff " << len1 <<" " << len2 <<" " << len3 << endl;
if(len1+len2<=len3){
Size[G1[x].num]+=num[son1]-num[u1];
continue;
}else if(len3+len1<=len2){
Size[G1[u1].num]+=num[son1]-num[u1];
continue;
}
len1=len1+len2+len3;
if(len1%2==1&&G1[u1].num<G1[x].num){
len1/=2;len1++;
}else{
len1/=2;
}
len1-=G1[u1].len;
int k=find1(u1,len1);
// cout << u1 <<" up "<< len1 << endl;
Size[G1[u1].num]+=num[k]-num[u1]; //
Size[G1[x].num]+=num[son1]-num[k];
// cout << "find good " << G1[u1].num << endl;
}
}
int count1;
if(root == x)count1=num[1];
else count1=num[x];
Size[G1[x].num]+=count1-sum1-1; //自己已经在里面了
// cout<<x <<" else ++ " << count1-sum1-1 << " "<<Size[G1[x].num] <<" "<<sum1<<endl;
if(vis2[x]==0)
Size[G1[x].num]++;
}
}
int main(){
int i,j,k,f1,f2,f3,f4,t1,t2,t3,t4,n,m;
int K;
//freopen("in.txt","r",stdin);
//freopen("out2.txt","w",stdout);
scanf("%d",&n);
init();
for(i=1;i<n;i++){
scanf("%d %d",&t1,&t2);
add(t1,t2);add(t2,t1);
}
dfs(1,0,1);
dfs2(1,1);
scanf("%d",&m);
for(i=1;i<=m;i++){
scanf("%d",&K);
//memset(G1,0,sizeof(G1));
for(j=1;j<=K;j++){
scanf("%d",&A[j]);
B[j]=A[j];
Size[A[j]]=vis2[A[j]]=vis1[A[j]]=1;
vv.push_back(A[j]);
}
top=0;
sort(A+1,A+1+K,cmp1);
root=A[1];
for(j=1;j<=K;j++)insert(A[j]);
while(top>1){
v[s[top]].push_back(s[top-1]);
v[s[top-1]].push_back(s[top]);top--;
}
DP();
dfs3(root,-1);
for(j=1;j<=K;j++)
printf("%d ",Size[B[j]]);
printf("\n");
for(j=0;j<vv.size();j++){
v[vv[j]].clear();
Size[vv[j]]=vis1[vv[j]]=vis2[vv[j]]=0;
}
vv.clear();
}
return 0;
}