题意:
n个点m条边,每次给出一个点集,可以删掉一个非点集中的点,问有多少删法使得存在两个点集中的点不连通。
题解:
建出圆方树的虚树,显然删去上面的圆点是合法的。
直接搞就好了。
code:
#include<vector>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
struct node{
int y,next;
}a[400010];int len,last[200010];
int n,m,dfn[200010],low[200010],cnt,z,sta[200010],top,ys[400010];
int h[400010],tim=0;
vector <int> vec[200010];
struct trnode{
int dep,fa[22],c;
}tr[400010];
void ins(int x,int y)
{
a[++len].y=y;
a[len].next=last[x];last[x]=len;
}
void tarjan(int x,int fa)
{
dfn[x]=low[x]=++z;sta[++top]=x;
for(int i=last[x];i;i=a[i].next)
{
if(i==(fa^1)) continue;
int y=a[i].y;
if(!dfn[y])
{
tarjan(y,i);
if(low[y]>=dfn[x])
{
cnt++;
while(1)
{
int t=sta[top--];
vec[cnt].push_back(t);
if(t==y) break;
}
vec[cnt].push_back(x);
}
else low[x]=min(low[x],low[y]);
}
else low[x]=min(low[x],dfn[y]);
}
}
void pre(int x,int fa)
{
tr[x].fa[0]=fa;tr[x].dep=tr[fa].dep+1;
ys[x]=++z;tr[x].c=tr[fa].c+(x<=n);
for(int i=1;(1<<i)<=tr[x].dep;i++)
tr[x].fa[i]=tr[tr[x].fa[i-1]].fa[i-1];
for(int i=last[x];i;i=a[i].next)
{
int y=a[i].y;
if(y==fa) continue;
pre(y,x);
}
}
int ans=0,tot[400010],num;
int findlca(int x,int y)
{
if(tr[x].dep<tr[y].dep) swap(x,y);
for(int i=20;i>=0;i--)
if((1<<i)<=tr[x].dep-tr[y].dep) x=tr[x].fa[i];
if(x==y) return x;
for(int i=20;i>=0;i--)
if((1<<i)<=tr[x].dep&&tr[x].fa[i]!=tr[y].fa[i])
x=tr[x].fa[i],y=tr[y].fa[i];
return tr[x].fa[0];
}
bool cmp(int a,int b) {return ys[a]<ys[b];}
void solve()
{
sort(h+1,h+num+1,cmp);
ans=top=0;sta[++top]=1;int k=0;
for(int i=1;i<=num;i++)
{
int lca=findlca(h[i],sta[top]);if(lca==1) k++;
if(lca==sta[top])
{
if(sta[top]!=h[i]) sta[++top]=h[i];
continue;
}
while(top>1&&tr[sta[top-1]].dep>=tr[lca].dep) ans+=tr[sta[top]].c-tr[sta[top-1]].c,top--;
if(lca!=sta[top]) ans+=tr[sta[top]].c-tr[lca].c,sta[top]=lca;
sta[++top]=h[i];
}
while(top>2) ans+=tr[sta[top]].c-tr[sta[top-1]].c,top--;
if(k>=2||h[1]==1) ans+=tr[sta[top]].c;
else if(top>1) ans+=(sta[top]<=n);
ans-=num;
}
int main()
{
int T;scanf("%d",&T);
while(T--)
{
scanf("%d %d",&n,&m);
len=1;memset(last,0,sizeof(last));
for(int i=1;i<=m;i++)
{
int x,y;scanf("%d %d",&x,&y);
ins(x,y);ins(y,x);
}
for(int i=1;i<=cnt;i++) vec[i].clear();
top=z=cnt=0;memset(dfn,0,sizeof(dfn));
tarjan(1,0);
len=0;memset(last,0,sizeof(last));
for(int i=1;i<=cnt;i++)
{
int x=i+n;
for(int j=0;j<vec[i].size();j++)
{
int y=vec[i][j];
ins(x,y);ins(y,x);
}
}
tr[0].dep=-1;z=0;pre(1,0);
int Q;scanf("%d",&Q);
while(Q--)
{
scanf("%d",&num);tim++;
for(int i=1;i<=num;i++) scanf("%d",&h[i]);
solve();printf("%d\n",ans);
}
}
}