题目大意
给定一个n个点的环套树,每条边有开、关两个状态。接下来m次操作,每次给定一对数(x,y),对于x到y的最短路径(如果有多条就选择经过点字典序的一条),路径上经过所有边的状态取反(开变关,关变开)。每次操作后输出只考虑开状态的边时,有多少个联通块。
n,m≤100000
分析
考虑在树上怎么做。
最短路径只有一条,那么直接可以确定。
对于当前的树,它的联通块个数就是n-状态为开的边的个数(连接一条边相当于合并两个联通块)。
现在变成了环套树,如果连接一条边,两个端点是来自不同联通块的,那么答案也是-1,否则答案不变。
由于只有一个环,当环上所有边都被选时才会出现答案不变的情况。那么当结论改成(答案为n-状态为开的边的个数+[环上的边都被选])是一样的。
由于是环套树,最短路径最多只有两条,判断的只是走到环上时,第一个不同的点而已。这个很好判断。
剩下的就是修改边的状态了,树链剖分解决!
注意x=y的情况。
时间复杂度 O(nlog2n)
我的代码很长很丑
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N=100005,M=200005,T=262222,Log=17;
typedef long long LL;
int n,m,h[N],e[M],nxt[M],ans,tot,s,H[N],E[M],Nxt[M],a[N],b[N],size[N],dfn[N],seq[N],top[N],r,sum[T],id[N];
int t1[T],t2[T],tag1[T],tag2[T],fa[N][Log],dep[N];
bool v[N];
char c;
int read()
{
for (c=getchar();c<'0' || c>'9';c=getchar());
int x=c-48;
for (c=getchar();c>='0' && c<='9';c=getchar()) x=x*10+c-48;
return x;
}
void add(int x,int y)
{
e[++tot]=y; nxt[tot]=h[x]; h[x]=tot;
}
void Add(int x,int y)
{
E[++tot]=y; Nxt[tot]=H[x]; H[x]=tot;
}
bool init(int x,int y)
{
v[x]=1; b[++tot]=x;
for (int i=h[x];i;i=nxt[i]) if (e[i]!=y)
{
if (!v[e[i]])
{
if (init(e[i],x)) return 1;
}else
{
for (int j=tot;b[j]!=e[i];j--) a[++s]=b[j];
a[++s]=e[i];
return 1;
}
}
tot--; v[x]=0;
return 0;
}
void build(int x,int ID)
{
v[x]=1; dep[x]=dep[fa[x][0]]+1; id[x]=ID;
for (int i=h[x];i;i=nxt[i]) if (!v[e[i]])
{
Add(x,e[i]); fa[e[i]][0]=x;
build(e[i],ID);
size[x]=size[e[i]]+1;
}
}
void ins(int l,int r,int v,int x)
{
sum[x]++;
if (l==r) return;
int mid=l+r>>1;
if (v<=mid) ins(l,mid,v,x<<1);else ins(mid+1,r,v,x<<1|1);
}
void dfs(int x,int y)
{
dfn[x]=++tot; seq[tot]=x; top[tot]=r;
ins(1,n,tot,1);
int i,j=0;
for (i=H[x];i;i=Nxt[i]) if (E[i]!=y && (!j || size[E[i]]>size[j])) j=E[i];
if (!j) return;
dfs(j,x);
for (i=H[x];i;i=Nxt[i]) if (E[i]!=y && E[i]!=j)
{
r=tot+1;
dfs(E[i],x);
}
}
void change1(int l,int r,int a,int b,int x)
{
if (l==a && r==b)
{
t1[x]=sum[x]-t1[x];
tag1[x]^=1;
return;
}
int mid=l+r>>1;
if (tag1[x])
{
tag1[x]=0;
tag1[x<<1]^=1; t1[x<<1]=sum[x<<1]-t1[x<<1];
tag1[x<<1|1]^=1; t1[x<<1|1]=sum[x<<1|1]-t1[x<<1|1];
}
if (b<=mid) change1(l,mid,a,b,x<<1);
else if (a>mid) change1(mid+1,r,a,b,x<<1|1);
else
{
change1(l,mid,a,mid,x<<1); change1(mid+1,r,mid+1,b,x<<1|1);
}
t1[x]=t1[x<<1]+t1[x<<1|1];
}
void change2(int l,int r,int a,int b,int x)
{
if (l==a && r==b)
{
t2[x]=r-l+1-t2[x];
tag2[x]^=1;
return;
}
int mid=l+r>>1;
if (tag2[x])
{
tag2[x]=0;
tag2[x<<1]^=1; t2[x<<1]=mid-l+1-t2[x<<1];
tag2[x<<1|1]^=1; t2[x<<1|1]=r-mid-t2[x<<1|1];
}
if (b<=mid) change2(l,mid,a,b,x<<1);
else if (a>mid) change2(mid+1,r,a,b,x<<1|1);
else
{
change2(l,mid,a,mid,x<<1); change2(mid+1,r,mid+1,b,x<<1|1);
}
t2[x]=t2[x<<1]+t2[x<<1|1];
}
int getlca(int x,int y)
{
if (dep[x]<dep[y]) x^=y^=x^=y;
for (int i=Log-1;i>=0;i--) if (dep[fa[x][i]]>=dep[y]) x=fa[x][i];
for (int i=Log-1;i>=0;i--) if (fa[x][i]!=fa[y][i])
{
x=fa[x][i]; y=fa[y][i];
}
if (x!=y) x=fa[x][0];
return x;
}
void go(int x,int y)
{
if (x==y) return;
x=dfn[x];
int k=top[x];
if (k<=dfn[y]) k=dfn[y]+1;
change1(1,n,k,x,1);
go(fa[seq[k]][0],y);
}
int main()
{
n=read(); m=read();
for (int i=0;i<n;i++)
{
int x=read(),y=read();
add(x,y); add(y,x);
}
tot=0;
init(1,0);
memset(v,0,sizeof(v));
for (int i=1;i<=s;i++) v[a[i]]=1;
tot=0; n++; dep[n]=1;
for (int i=1;i<=s;i++)
{
id[a[i]]=i;
for (int j=h[a[i]];j;j=nxt[j]) if (!v[e[j]])
{
Add(n,e[j]); fa[e[j]][0]=n;
build(e[j],i);
}
}
r=1; tot=0;
dfs(n,0);
for (int j=1;j<Log;j++)
{
for (int i=1;i<=n;i++) fa[i][j]=fa[fa[i][j-1]][j-1];
}
ans=n-1;
while (m--)
{
int x=read(),y=read(),lca;
if (x==y)
{
printf("%d\n",ans); continue;
}
if (a[id[x]]!=x && a[id[y]]!=y) lca=getlca(x,y);else lca=n;
if (a[id[x]]!=x) go(x,lca);
if (a[id[y]]!=y) go(y,lca);
if (lca==n)
{
int p=id[x],q=id[y],pre=(p>1)?p-1:s,nxt=(p<s)?p+1:1;
if (p<q)
{
if (q-p<s/2.0 || q-p==s/2.0 && a[nxt]<a[pre]) change2(1,s,p,q-1,1);
else
{
if (p>1) change2(1,s,1,p-1,1);
change2(1,s,q,s,1);
}
}else if (p>q)
{
if (p-q<s/2.0 || p-q==s/2.0 && a[nxt]>a[pre]) change2(1,s,q,p-1,1);
else
{
change2(1,s,p,s,1);
if (q>1) change2(1,s,1,q-1,1);
}
}
}
ans=n-t1[1]-t2[1]-1;
if (t2[1]==s) ans++;
printf("%d\n",ans);
}
return 0;
}