题意
这题就是楼教主男人必做八题之一 给出一棵有边权的树,以及一个数K,求距离小于等于K的点对的个数。
题解
点分治显然可做,对于当前点分树,把所有点到当前根的距离排序后扫一下即可。
复杂度
O(nlog22n)
。
这里主要讲一下另一种不错的思路——平衡树+启发式合并。
具体做法:
随便找个根,然后递归下去从叶到根把子树不断合并。过程中,给每个子树建平衡树来存其中所有点到子树的根的距离,并在合并的时候计算两点在不同子树,且路径经过子树根的父节点的答案。
考虑以x的两个儿子son1, son2为根的两颗子树,进行启发式合并:
先把son1中的元素全部+w(x,son1),把son2中的元素全部+w(x,son2)。
然后把其中的节点数小的那棵子树中的元素(假设是son1)一个一个拿出来,并依次到son2中求一下对答案的贡献(平衡树询问rank)。
询问完后一个一个插入到son2中,合并完成。
要注意这里必须先查询完在插入,而不能求一个插一个,因为这样会计算到同一个子树中的点对。
平衡树中询问
O(log2n)
,启发式合并
O(nlog2n)
。总复杂度
O(nlog22n)
。
启发式合并的复杂度是否科学呢?
其实是很显然的。我们每次合并都是把小的往大的里一个一个插入。考虑某个节点,它每进行一次插入操作(从一棵树插入到另一棵)一次,所在的树的size至少乘2,所以一个节点最多移动
O(log2n)
次,总复杂度
O(nlog2n)
。很简单又很妙的东西。
感觉这种方法和点分治还是比较类似的,区别就在于没有改变树的深度,而是采用数据结构维护信息来保证复杂度。
平衡树+启发式合并代码:
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn=10005, maxe=20005;
struct node{
int key,fix,size,cnt,tag;
node* ch[2];
void maintain(){ size=cnt+ch[0]->size+ch[1]->size; }
void plus(int val){ if(!size) return; key+=val; tag+=val; }
void pushdown(){
if(!tag) return;
ch[0]->plus(tag); ch[1]->plus(tag);
tag=0;
}
} nil, *null, base[maxn*50], *len;
typedef node* P_node;
void Treap_init(){
null=&nil;
null->cnt=null->size=null->tag=0; null->fix=-1e+9;
null->ch[0]=null->ch[1]=null;
len=base;
}
P_node newnode(int tkey){
len->key=tkey; len->fix=rand();
len->size=len->cnt=1; len->tag=0;
len->ch[0]=len->ch[1]=null;
return len++;
}
void rot(P_node &p,int d){
p->pushdown();
P_node k=p->ch[d^1]; k->pushdown();
p->ch[d^1]=k->ch[d]; k->ch[d]=p;
p->maintain(); k->maintain(); p=k;
}
void Insert(P_node &p,int tkey){
if(p==null) p=newnode(tkey); else
if(p->key==tkey) p->cnt++; else{
p->pushdown();
int d=tkey>p->key;
Insert(p->ch[d],tkey); if(p->ch[d]->fix>p->fix) rot(p,d^1);
}
p->maintain();
}
int Rank(P_node p,int tkey,int res){
if(p->key==tkey||p==null) return p->ch[0]->size+res+1;
p->pushdown();
if(tkey<p->key) return Rank(p->ch[0],tkey,res);
return Rank(p->ch[1],tkey,res+ p->ch[0]->size + p->cnt);
}
int n,m,ans,pre[maxn],fir[maxn],nxt[maxe],son[maxe],w[maxe],tot,c[maxn];
void add(int x,int y,int z){
son[++tot]=y; w[tot]=z; nxt[tot]=fir[x]; fir[x]=tot;
}
void Print(P_node p){
if(p==null) return;
p->pushdown();
Print(p->ch[0]);
for(int i=1;i<=p->cnt;i++) c[++c[0]]=p->key;
Print(p->ch[1]);
}
P_node merge(P_node r1,P_node r2){
if(r1->size<r2->size) swap(r1,r2);
c[0]=0; Print(r2);
for(int i=1;i<=c[0];i++) ans+=Rank(r1,m-c[i]+1,0)-1;
for(int i=1;i<=c[0];i++) Insert(r1,c[i]);
return r1;
}
P_node dfs(int x){
P_node p=newnode(0);
for(int j=fir[x];j;j=nxt[j]) if(son[j]!=pre[x]){
pre[son[j]]=x; P_node now=dfs(son[j]);
now->plus(w[j]);
p=merge(p,now);
}
return p;
}
int main(){
freopen("poj1741.in","r",stdin);
freopen("poj1741.out","w",stdout);
for(scanf("%d%d",&n,&m);!(!n&&!m);scanf("%d%d",&n,&m)){
Treap_init(); ans=0;
memset(fir,0,sizeof(fir)); tot=0;
for(int i=1;i<=n-1;i++){
int x,y,z; scanf("%d%d%d",&x,&y,&z);
add(x,y,z); add(y,x,z);
}
dfs(1);
printf("%d\n",ans);
}
return 0;
}
点分治代码:
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn=10005, maxe=20005;
int n,m,root,_min,ans,len,sz[maxn],fa[maxn],d[maxn],num[maxn];
int fir[maxn],nxt[maxe],son[maxe],w[maxe],pre[maxn],tot;
bool vis[maxn];
struct data{
int x;
bool operator < (const data &b)const{
return d[x]<d[b.x];
}
} a[maxn],c[maxn];
void add(int x,int y,int z){
son[++tot]=y; w[tot]=z; nxt[tot]=fir[x]; fir[x]=tot;
}
void dfs(int x){
sz[x]=1; a[++len].x=x;
for(int j=fir[x];j;j=nxt[j]) if(son[j]!=pre[x]&&!vis[son[j]]){
pre[son[j]]=x; d[son[j]]=d[x]+w[j];
fa[son[j]]=pre[x]?fa[x]:son[j];
dfs(son[j]);
sz[x]+=sz[son[j]];
}
}
void get_hvy(int x,int rt){
int res=0;
for(int j=fir[x];j;j=nxt[j]) if(son[j]!=pre[x]&&!vis[son[j]])
get_hvy(son[j],rt), res=max(res,sz[son[j]]);
if(pre[x]) res=max(res,sz[rt]-sz[x]);
if(res<_min) _min=res, root=x;
}
void msort(int L,int R){
if(L>=R) return;
int mid=(L+R)>>1;
msort(L,mid); msort(mid+1,R);
int i=L,j=mid+1;
for(int k=L;k<=R;k++) c[k]=a[k];
for(int k=L;k<=R;k++) if(i<=mid&&(j>R||c[i]<c[j])) a[k]=c[i++];
else a[k]=c[j++];
}
void get(int x){
_min=1e+9; get_hvy(x,x);
fa[root]=pre[root]=d[root]=len=0; dfs(root);
msort(1,len);
memset(num,0,sizeof(num));
int now=0;
for(int i=len;i>=1;i--){
while(now<len&&d[a[now+1].x]<=m-d[a[i].x]) num[fa[a[++now].x]]++;
ans+=now-num[fa[a[i].x]];
}
vis[root]=true;
for(int j=fir[root];j;j=nxt[j]) if(!vis[son[j]]) get(son[j]);
}
int main(){
freopen("poj1741.in","r",stdin);
freopen("poj1741.out","w",stdout);
while(scanf("%d%d",&n,&m)==2&&n){
memset(fir,0,sizeof(fir)); tot=0;
memset(vis,0,sizeof(vis));
ans=0;
for(int i=1;i<=n-1;i++){
int x,y,z; scanf("%d%d%d",&x,&y,&z);
add(x,y,z); add(y,x,z);
}
fa[1]=pre[1]=d[1]=len=0; dfs(1);
get(1);
printf("%d\n",ans/2);
}
return 0;
}