题意:给一张 n n n 点 m m m 边的连通无向图, q q q 次询问,每次给出一个点集 S S S ,求有多少个不在 S S S 中的点满足删除后 S S S 中存在两个点不连通。
n ≤ 1 0 5 , m ≤ 2 × 1 0 5 , ∑ ∣ S ∣ ≤ 2 × 1 0 5 n\leq 10^5,m\leq 2\times 10^5,\sum |S|\leq 2\times 10^5 n≤105,m≤2×105,∑∣S∣≤2×105
显然是虚树
题目相当于求 S S S 的割点数量,想到圆方树
建出圆方树后,对于 S S S 中的一对点,它们路径上任意一个圆点都满足条件。答案相当于求两两路径上的圆点的并集。因为不能取 S S S 中的点,所以要减去 ∣ S ∣ |S| ∣S∣。
套到虚树上,发现就是虚树覆盖的圆点个数。即对于每个点 u u u ,设 f a u fa_u fau 为其虚树上的父亲, s u m u sum_u sumu 为原树上根到 u u u 的圆点个数。那么答案为 ∑ ( s u m u − s u m f a u ) \sum(sum_u-sum_{fa_u}) ∑(sumu−sumfau)。
注意要特判 lca ( S ) \operatorname{lca}(S) lca(S)
复杂度 O ( n + m + ∑ ∣ S ∣ log ∣ S ∣ ) O(n+m+\sum|S|\log |S|) O(n+m+∑∣S∣log∣S∣)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <vector>
#include <algorithm>
#define MAXN 400005
#define MAXM 800005
using namespace std;
inline int read()
{
int ans=0;
char c=getchar();
while (!isdigit(c)) c=getchar();
while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
return ans;
}
struct edge{int u,v;}e[MAXM];
int head[MAXN],nxt[MAXM],cnt;
inline void addnode(int u,int v)
{
e[++cnt]=(edge){u,v};
nxt[cnt]=head[u];
head[u]=cnt;
}
int n,m;
int dfn[MAXN],low[MAXN],tim;
int stk[MAXN],tp,vis[MAXN],bcc[MAXN],vcnt;
vector<int> rtt[MAXN];
void tarjan(int u)
{
dfn[u]=low[u]=++tim;
for (int i=head[u];i;i=nxt[i])
{
if (!vis[i>>1]&&!bcc[i>>1]) vis[(stk[++tp]=i)>>1]=1;
if (!dfn[e[i].v])
{
tarjan(e[i].v);
low[u]=min(low[u],low[e[i].v]);
if (dfn[u]==low[e[i].v])
{
rtt[u].push_back(bcc[i>>1]=++vcnt);
rtt[bcc[i>>1]].push_back(u);
while (vis[i>>1])
{
int t=stk[tp--];
vis[t>>1]=0;
rtt[bcc[t>>1]=vcnt].push_back(e[t].v);
}
}
}
else low[u]=min(low[u],dfn[e[i].v]);
}
}
int sum[MAXN],dep[MAXN],fa[MAXN][20];
void dfs(int u)
{
dfn[u]=++tim;
for (int i=1;i<20;i++) fa[u][i]=fa[fa[u][i-1]][i-1];
for (int i=0;i<(int)rtt[u].size();i++)
if (!dep[rtt[u][i]])
{
dep[rtt[u][i]]=dep[u]+1;
sum[rtt[u][i]]=sum[u]+(rtt[u][i]<=n);
fa[rtt[u][i]][0]=u;
dfs(rtt[u][i]);
}
}
inline int lca(int x,int y)
{
if (dep[x]<dep[y]) swap(x,y);
int t=dep[x]-dep[y];
for (int i=0;(1<<i)<=t;i++) if (t&(1<<i)) x=fa[x][i];
if (x==y) return x;
for (int i=19;i>=0;i--) if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
int lis[MAXM],len;
inline bool cmp(const int& x,const int& y){return dfn[x]<dfn[y];}
inline void solve()
{
sort(lis+1,lis+len+1,cmp);
int res=len;
for (int i=1;i<res;i++) lis[++len]=lca(lis[i],lis[i+1]);
sort(lis+1,lis+len+1,cmp);
len=unique(lis+1,lis+len+1)-lis-1;
int ans=0;
tp=0;
for (int i=1;i<=len;i++)
{
while (tp&&lca(stk[tp],lis[i])!=stk[tp]) --tp;
ans+=(tp? sum[lis[i]]-sum[stk[tp]]:(lis[i]<=n));
stk[++tp]=lis[i];
}
printf("%d\n",ans-res);
}
int main()
{
for (int T=read();T;T--)
{
n=read(),m=read();
cnt=1;
memset(head,0,sizeof(head));
memset(nxt,0,sizeof(nxt));
memset(bcc,0,sizeof(bcc));
memset(dep,0,sizeof(dep));
memset(sum,0,sizeof(sum));
memset(dfn,0,sizeof(dfn));
for (int i=1;i<=vcnt;i++) rtt[i].clear();
tim=tp=0;
for (int i=1;i<=m;i++)
{
int u,v;
u=read(),v=read();
addnode(u,v),addnode(v,u);
}
vcnt=n;
tarjan(1);
for (int i=n+1;i<=vcnt;i++)
{
sort(rtt[i].begin(),rtt[i].end());
rtt[i].erase(unique(rtt[i].begin(),rtt[i].end()),rtt[i].end());
}
dep[1]=sum[1]=1,tim=0;
dfs(1);
for (int q=read();q;q--)
{
len=read();
for (int i=1;i<=len;i++) lis[i]=read();
solve();
}
}
return 0;
}