[平衡树+启发式合并 || 点分治] POJ1741 Tree

题意

这题就是楼教主男人必做八题之一 给出一棵有边权的树,以及一个数K,求距离小于等于K的点对的个数。

题解

点分治显然可做,对于当前点分树,把所有点到当前根的距离排序后扫一下即可。
复杂度 O(nlog22n)
这里主要讲一下另一种不错的思路——平衡树+启发式合并。
具体做法:
随便找个根,然后递归下去从叶到根把子树不断合并。过程中,给每个子树建平衡树来存其中所有点到子树的根的距离,并在合并的时候计算两点在不同子树,且路径经过子树根的父节点的答案。
这里写图片描述
考虑以x的两个儿子son1, son2为根的两颗子树,进行启发式合并:
先把son1中的元素全部+w(x,son1),把son2中的元素全部+w(x,son2)。
然后把其中的节点数小的那棵子树中的元素(假设是son1)一个一个拿出来,并依次到son2中求一下对答案的贡献(平衡树询问rank)。
询问完后一个一个插入到son2中,合并完成。
要注意这里必须先查询完在插入,而不能求一个插一个,因为这样会计算到同一个子树中的点对。
平衡树中询问 O(log2n) ,启发式合并 O(nlog2n) 。总复杂度 O(nlog22n)
启发式合并的复杂度是否科学呢?
其实是很显然的。我们每次合并都是把小的往大的里一个一个插入。考虑某个节点,它每进行一次插入操作(从一棵树插入到另一棵)一次,所在的树的size至少乘2,所以一个节点最多移动 O(log2n) 次,总复杂度 O(nlog2n) 。很简单又很妙的东西。

感觉这种方法和点分治还是比较类似的,区别就在于没有改变树的深度,而是采用数据结构维护信息来保证复杂度。

平衡树+启发式合并代码:

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn=10005, maxe=20005;
struct node{
    int key,fix,size,cnt,tag;
    node* ch[2];
    void maintain(){ size=cnt+ch[0]->size+ch[1]->size; }
    void plus(int val){ if(!size) return; key+=val; tag+=val; }
    void pushdown(){
        if(!tag) return;
        ch[0]->plus(tag); ch[1]->plus(tag);
        tag=0;
    }
} nil, *null, base[maxn*50], *len;
typedef node* P_node;
void Treap_init(){
    null=&nil;
    null->cnt=null->size=null->tag=0; null->fix=-1e+9;
    null->ch[0]=null->ch[1]=null;
    len=base;
}
P_node newnode(int tkey){
    len->key=tkey; len->fix=rand();
    len->size=len->cnt=1; len->tag=0;
    len->ch[0]=len->ch[1]=null;
    return len++;
}
void rot(P_node &p,int d){
    p->pushdown();
    P_node k=p->ch[d^1]; k->pushdown(); 
    p->ch[d^1]=k->ch[d]; k->ch[d]=p;
    p->maintain(); k->maintain(); p=k;
}
void Insert(P_node &p,int tkey){
    if(p==null) p=newnode(tkey); else
    if(p->key==tkey) p->cnt++; else{
        p->pushdown();
        int d=tkey>p->key;
        Insert(p->ch[d],tkey); if(p->ch[d]->fix>p->fix) rot(p,d^1);
    }
    p->maintain();
}
int Rank(P_node p,int tkey,int res){
    if(p->key==tkey||p==null) return p->ch[0]->size+res+1;
    p->pushdown();
    if(tkey<p->key) return Rank(p->ch[0],tkey,res);
    return Rank(p->ch[1],tkey,res+ p->ch[0]->size + p->cnt); 
}
int n,m,ans,pre[maxn],fir[maxn],nxt[maxe],son[maxe],w[maxe],tot,c[maxn];
void add(int x,int y,int z){
    son[++tot]=y; w[tot]=z; nxt[tot]=fir[x]; fir[x]=tot;
}
void Print(P_node p){
    if(p==null) return;
    p->pushdown();
    Print(p->ch[0]);
    for(int i=1;i<=p->cnt;i++) c[++c[0]]=p->key;
    Print(p->ch[1]);
}
P_node merge(P_node r1,P_node r2){
    if(r1->size<r2->size) swap(r1,r2);
    c[0]=0; Print(r2);
    for(int i=1;i<=c[0];i++) ans+=Rank(r1,m-c[i]+1,0)-1;
    for(int i=1;i<=c[0];i++) Insert(r1,c[i]);
    return r1;
}
P_node dfs(int x){
    P_node p=newnode(0);
    for(int j=fir[x];j;j=nxt[j]) if(son[j]!=pre[x]){
        pre[son[j]]=x; P_node now=dfs(son[j]);
        now->plus(w[j]); 
        p=merge(p,now);
    }
    return p;
}
int main(){
    freopen("poj1741.in","r",stdin);
    freopen("poj1741.out","w",stdout);
    for(scanf("%d%d",&n,&m);!(!n&&!m);scanf("%d%d",&n,&m)){
        Treap_init(); ans=0;
        memset(fir,0,sizeof(fir)); tot=0;
        for(int i=1;i<=n-1;i++){
            int x,y,z; scanf("%d%d%d",&x,&y,&z);
            add(x,y,z); add(y,x,z);
        }
        dfs(1);
        printf("%d\n",ans); 
    }
    return 0;
}

点分治代码:

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn=10005, maxe=20005;
int n,m,root,_min,ans,len,sz[maxn],fa[maxn],d[maxn],num[maxn];
int fir[maxn],nxt[maxe],son[maxe],w[maxe],pre[maxn],tot;
bool vis[maxn];
struct data{
    int x;
    bool operator < (const data &b)const{
        return d[x]<d[b.x];
    }
} a[maxn],c[maxn];
void add(int x,int y,int z){
    son[++tot]=y; w[tot]=z; nxt[tot]=fir[x]; fir[x]=tot;
}
void dfs(int x){
    sz[x]=1; a[++len].x=x;
    for(int j=fir[x];j;j=nxt[j]) if(son[j]!=pre[x]&&!vis[son[j]]){
        pre[son[j]]=x; d[son[j]]=d[x]+w[j]; 
        fa[son[j]]=pre[x]?fa[x]:son[j];
        dfs(son[j]);
        sz[x]+=sz[son[j]];
    }
}
void get_hvy(int x,int rt){
    int res=0;
    for(int j=fir[x];j;j=nxt[j]) if(son[j]!=pre[x]&&!vis[son[j]]) 
     get_hvy(son[j],rt), res=max(res,sz[son[j]]);
    if(pre[x]) res=max(res,sz[rt]-sz[x]);
    if(res<_min) _min=res, root=x;
}
void msort(int L,int R){
    if(L>=R) return;
    int mid=(L+R)>>1;
    msort(L,mid); msort(mid+1,R);
    int i=L,j=mid+1;
    for(int k=L;k<=R;k++) c[k]=a[k];
    for(int k=L;k<=R;k++) if(i<=mid&&(j>R||c[i]<c[j])) a[k]=c[i++]; 
                                                  else a[k]=c[j++];  
}
void get(int x){ 
    _min=1e+9; get_hvy(x,x); 
    fa[root]=pre[root]=d[root]=len=0; dfs(root);
    msort(1,len); 
    memset(num,0,sizeof(num));
    int now=0;
    for(int i=len;i>=1;i--){
        while(now<len&&d[a[now+1].x]<=m-d[a[i].x]) num[fa[a[++now].x]]++;
        ans+=now-num[fa[a[i].x]];
    }
    vis[root]=true;
    for(int j=fir[root];j;j=nxt[j]) if(!vis[son[j]]) get(son[j]);
}
int main(){
    freopen("poj1741.in","r",stdin);
    freopen("poj1741.out","w",stdout);
    while(scanf("%d%d",&n,&m)==2&&n){
        memset(fir,0,sizeof(fir)); tot=0;
        memset(vis,0,sizeof(vis));
        ans=0;
        for(int i=1;i<=n-1;i++){
            int x,y,z; scanf("%d%d%d",&x,&y,&z);
            add(x,y,z); add(y,x,z);
        }
        fa[1]=pre[1]=d[1]=len=0; dfs(1); 
        get(1);
        printf("%d\n",ans/2);   
    }
    return 0;
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值