题意:
给你一张无向连通图,有n个点m条边,边的编号按照输入顺序来排。有q次询问,每次询问给出两个点x和y,询问从x和y分别出发,一共经过了z个点经过的所有边的最大编号最小是多少。n,m,q<=1e5
题解:
这题是我做的第一道整体二分的题,所以用这道题为例做一下学习笔记。
首先对于这道题,我们看到最大值最小,不难想到一种单次询问用二分+并查集的做法,但是多次询问复杂度爆炸。但是我们仍然沿用二分答案这个想法,我们这里要用到一种整体二分的方法。
我按照我的理解写一下整体二分。可能会有诸多不严谨或者错误的地方,望诸位大神斧正。
整体二分是一种离线算法,它的思想是把所有询问离线下来,然后二分答案,每次对于当前二分的答案,我们把所有询问都带进去检验,看当前二分的答案对于每一组是否可行。我们记录哪些是当前可行的,哪些是当前答案不可行的,然后再分别递归到
[
l
,
m
i
d
]
[l,mid]
[l,mid]和
[
m
i
d
+
1
,
r
]
[mid+1,r]
[mid+1,r]检验可行性。
对于这道题,我们发现并查集的话在二分的答案变化后可能需要撤回,于是不能路径压缩,为了保证复杂度,我们采用启发式合并,这样保证树高是 l o g n logn logn的。我们对于当前二分的值,把编号在 [ l , m i d ] [l,mid] [l,mid]之间的边合并,并且为了撤回,记录下是从谁合并到了谁,然后更新 s i z e size size。接下来对于所有递归到这个区间的询问一一进行判断,看在每一个询问下一步递归应该递归到 [ l , m i d ] [l,mid] [l,mid]还是 [ m i d + 1 , r ] [mid+1,r] [mid+1,r]。判断完了之后要先撤回当前操作,因为你会先递归到左区间,左区间的答案会小于等于当前 m i d mid mid,所以现在需要暂时撤回。而我们递归到一个 l = r l=r l=r的区间时就意味着递归到这段区间的询问的答案就是 l l l(或者说是 r r r),这时由于我们先递归左区间,上面一层又撤回了,因为我们要在递归右区间时只会加当前区间的这些边,于是要保证之前编号更小的边已经加过了,那么我们就在递归到底的时候不会撤回地合并。
以上就是这题的做法和我理解的整体二分的思想了,下面是代码。
代码:
#include <bits/stdc++.h>
using namespace std;
int n,m,qq,ans[100010],f[100010],sz[100010];
struct node
{
int x,y;
}a[100010];
struct qwq
{
int x,y,z,id;
}q[100010],ji[100010];
stack<node> sta;
inline int getr(int x)
{
if(x==f[x])
return x;
else
return getr(f[x]);
}
inline void solve(int l,int r,int x,int y)
{
if(l==r)
{
for(int i=x;i<=y;++i)
ans[q[i].id]=l;
int fx=getr(a[l].x),fy=getr(a[l].y);
if(sz[fx]>sz[fy])
swap(fx,fy);
if(fx!=fy)
{
f[fx]=fy;
sz[fy]+=sz[fx];
}
return;
}
int mid=(l+r)>>1;
for(int i=l;i<=mid;++i)
{
int fx=getr(a[i].x),fy=getr(a[i].y);
if(sz[fx]>sz[fy])
swap(fx,fy);
if(fx!=fy)
{
f[fx]=fy;
sz[fy]+=sz[fx];
sta.push((node){fx,fy});
}
}
int cnt1=x-1,cnt2=0;
for(int i=x;i<=y;++i)
{
int fx=getr(q[i].x),fy=getr(q[i].y),size;
if(fx==fy)
size=sz[fx];
else
size=sz[fx]+sz[fy];
if(size>=q[i].z)
q[++cnt1]=q[i];
else
ji[++cnt2]=q[i];
}
for(int i=1;i<=cnt2;++i)
q[cnt1+i]=ji[i];
while(!sta.empty())
{
node qwqqq=sta.top();
sta.pop();
f[qwqqq.x]=qwqqq.x;
sz[qwqqq.y]-=sz[qwqqq.x];
}
solve(l,mid,x,cnt1);
solve(mid+1,r,cnt1+1,y);
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=m;++i)
scanf("%d%d",&a[i].x,&a[i].y);
for(int i=1;i<=n;++i)
{
f[i]=i;
sz[i]=1;
}
scanf("%d",&qq);
for(int i=1;i<=qq;++i)
{
scanf("%d%d%d",&q[i].x,&q[i].y,&q[i].z);
q[i].id=i;
}
solve(1,m,1,qq);
for(int i=1;i<=qq;++i)
printf("%d\n",ans[i]);
return 0;
}