POJ 2985 The k-th LargestGroup(Treap+并查集)
http://poj.org/problem?id=2985
题意:
有N只猫,开始每只猫都是一个小组,下面要执行M个操作,操作0 i j 是把i猫和j猫所属的小组合并,操作1 k 是问你当前第k大的小组大小是多少. 且k<=当前的最大组数.
分析:
用并查集维护每只猫所属的集合,且维护集合中的节点数.
然后假设集合S1和集合S2合并,就在Treap中删除V=|S1|的节点并删除V=|S2|的节点.并插入V=|S1|+|S2|的节点.(注意如果|S1|==1,不用删除,因为Treap只保存了合并之后的组大小)
然后查询的时候就查询Treap中第K大的节点v值即可.这里要注意,如果K>Treap中的节点总数,就默认输出1.(因为我们Treap只保存了合并之后的组大小)
AC代码:
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
using namespace std;
struct Node
{
Node *ch[2];
int r,v,s;
Node(int v):v(v)
{
r=rand();
s=1;
ch[0]=ch[1]=NULL;
}
void maintain()
{
s=1;
if(ch[0]) s+=ch[0]->s;
if(ch[1]) s+=ch[1]->s;
}
int cmp(int x)
{
if(x==v)return -1;
return x<v?0:1;
}
}*root;
void rotate(Node* &o,int d)
{
Node *k=o->ch[d^1];
o->ch[d^1]=k->ch[d];
k->ch[d]=o;
o->maintain();
k->maintain();
o=k;
}
void insert(Node* &o,int v)//可插入重复值
{
if(o==NULL) o=new Node(v);
else
{
int d=v < o->v? 0:1;
insert(o->ch[d],v);
if(o->ch[d]->r > o->r)
rotate(o,d^1);
}
o->maintain();
}
void remove(Node* &o,int v)
{
int d=o->cmp(v);
if(d==-1)
{
Node *u=o;
if(o->ch[0] && o->ch[1])
{
int d2= o->ch[0]->r < o->ch[1]->r ?0:1;
rotate(o,d2);
remove(o->ch[d2],v);
}
else
{
if(o->ch[0]==NULL) o=o->ch[1];
else o=o->ch[0];
delete u;
}
}
else remove(o->ch[d],v);
if(o) o->maintain();
}
int kth(Node *o,int k)//返回第k大的值,不是第k小
{
if(o==NULL || k<=0 || k> o->s) return 1;
int s = (o->ch[1]==NULL)?0:o->ch[1]->s;
if(k==s+1) return o->v;
else if(k<=s) return kth(o->ch[1],k);
else return kth(o->ch[0],k-s-1);
}
const int maxn=200000+1000;
int n,m;
int F[maxn],size[maxn];
int findset(int i)
{
if(F[i]==-1) return i;
return F[i]=findset(F[i]);
}
void bind(int i,int j)
{
int fa=findset(i);
int fb=findset(j);
if(fa!=fb)
{
if(size[fa]!=1)remove(root,size[fa]);
if(size[fb]!=1)remove(root,size[fb]);
insert(root,size[fa]+size[fb]);
F[fa]=fb;
size[fb] += size[fa];
}
}
int main()
{
root=NULL;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
{
F[i]=-1;
size[i]=1;
}
while(m--)
{
int op,i,j,k;
scanf("%d",&op);
if(op==0)
{
scanf("%d%d",&i,&j);
bind(i,j);
}
else if(op==1)
{
scanf("%d",&k);
printf("%d\n",kth(root,k));
}
}
return 0;
}