问题描述:给定一棵N个结点的带权树, 问有多少条路径使得它的长度<=k,n<=10000,k<=1000000000。
题目分析:……本题是我第一次敲树上点分治,AC了有一点小激动,但代码可能还不是很正规……
好吧先来讲讲做法,有点像CDQ分治。我们只要每一次求出以root为根的子树中有多少路径经过了root且长度<=k的,然后再递归root的子树即可。我们先Dfs一遍root为根的子树,求出每一个节点i到root的距离dep[i]。我们知道,如果dep[u]+dep[v]<=k还不行,我们必须要让u,v来自root的不同子树才行。于是我们记录belong[u]代表u来自以belong[u]为根的子树,其中belong[u]是root的儿子。这样当belong[u]!=belong[v]的时候,u->v的路径就可以被统计进答案。注意u==root或v==root的情况要小心。
于是我们对子树的dep从小到大排个序,然后利用单调,记录一个最大的tail使得dep[i]+dep[tail]<=k,让tail不断左移。我们还要用一个下标数组cnt[i]记录一下1——tail中belong为i的个数,tail减的时候更新这个数组就可以了。注意最后一定要让tail减到0,因为让tail减到0的时间代价比清空cnt数组的代价要小得多。如果每一次都清空cnt数组的话,等一会儿我们画出递归树来就知道,这是O(n^2)的。
然而这样时间复杂度是多少呢?对于每一次操作,我们都要用O(k*log(k))的时间,其中k是子树的大小,万一这是一条链的话,时间就是O(n^2*log(n))的了因为递归的深度会达到n。
我们可以每一次都选取树的重心,作为树的root,这样递归深度就是log(n)的了。然而如何得到树的重心呢?对于每一次的子树,我们都做一次树形DP。先随机选取一个点为根root,然后记录以每一个点i为根的子树大小size[i],以及它儿子中最大的size,记为maxson[i]。这样当以i为根的时候,此时它最大的儿子拥有的点数为max(maxson[i],size[root]-size[i])。然后在子树中所有的点中取min就好了。以它作为新的root再来一遍Dfs……QAQ
以上的做法时间均与子树大小成正比。只不过很麻烦……我们画出算法的递归树:
很明显,一个点只会被计算log(n)次,故时间为O(n*log^2(n))。注意我们不能在递归时清空任何一个全局变量数组,否则时间为O(n^2)!
还有,做完一个点之后,我们要断开它和它儿子的所有边,于是要对空间池和反向边做一些标记,超级麻烦……
CODE:
#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
using namespace std;
const int maxn=10010;
struct data
{
int obj,val,id,rev,_Next;
} e[maxn<<1];
int head[maxn];
int cur;
int efa[maxn];
int fa[maxn];
int nTree[maxn];
int curT;
int max_son[maxn];
int _Size[maxn];
int root;
struct Tnode
{
int dep,belong;
} A[maxn];
int cnt[maxn];
int n,k;
int ans;
void Add(int x,int y,int v,int Rev)
{
cur++;
e[cur].obj=y;
e[cur].val=v;
e[cur].id=1;
e[cur].rev=Rev;
e[cur]._Next=head[x];
head[x]=cur;
}
void Dfs(int node)
{
int p=head[node];
while (p!=-1)
{
int son=e[p].obj;
if (son!=fa[node])
{
fa[son]=node;
efa[son]=p;
Dfs(son);
}
p=e[p]._Next;
}
}
bool Comp(Tnode x,Tnode y)
{
return x.dep<y.dep;
}
void Dfs1(int node,int nfa)
{
curT++;
nTree[curT]=node;
max_son[node]=0;
_Size[node]=1;
int p=head[node];
while (p!=-1)
{
int son=e[p].obj;
if ( son!=nfa && e[p].id )
{
Dfs1(son,node);
max_son[node]=max(max_son[node],_Size[son]);
_Size[node]+=_Size[son];
}
p=e[p]._Next;
}
}
void Dfs2(int node,int nfa,int Aid)
{
int p=head[node];
while (p!=-1)
{
int son=e[p].obj;
if ( son!=nfa && e[p].id )
{
curT++;
A[curT].dep=A[Aid].dep+e[p].val;
if (node==root) A[curT].belong=son;
else A[curT].belong=A[Aid].belong;
Dfs2(son,node,curT);
}
p=e[p]._Next;
}
}
void Solve(int node)
{
curT=0;
Dfs1(node,node);
if (curT<=1) return;
root=node;
for (int i=2; i<=curT; i++)
if ( max(max_son[ nTree[i] ],_Size[node]-_Size[ nTree[i] ]) <
max(max_son[root],_Size[node]-_Size[root]) )
root=nTree[i];
curT=1;
A[1].dep=0;
A[1].belong=0;
Dfs2(root,root,1);
sort(A+1,A+curT+1,Comp);
for (int i=1; i<=curT; i++) cnt[ A[i].belong ]++;
int tail=curT;
for (int i=1; i<=curT; i++)
{
while ( A[i].dep+A[tail].dep>k && tail )
{
cnt[ A[tail].belong ]--;
tail--;
}
ans+=(tail-cnt[ A[i].belong ]);
}
while (tail)
{
cnt[ A[tail].belong ]--;
tail--;
}
node=root;
int p=head[node];
while (p!=-1)
{
int son=e[p].obj;
if ( e[p].id )
{
if (son==fa[node])
{
e[ efa[node] ].id=0;
e[ efa[node]+e[ efa[node] ].rev ].id=0;
}
else
{
e[ efa[son] ].id=0;
e[ efa[son]+e[ efa[son] ].rev ].id=0;
}
Solve(son);
}
p=e[p]._Next;
}
}
int main()
{
freopen("c.in","r",stdin);
freopen("c.out","w",stdout);
scanf("%d%d",&n,&k);
while ( n!=0 || k!=0 )
{
cur=-1;
for (int i=1; i<=n; i++) head[i]=-1;
for (int i=1; i<n; i++)
{
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
Add(a,b,c,1);
Add(b,a,c,-1);
}
fa[1]=1;
efa[1]=0;
Dfs(1);
ans=0;
Solve(1);
printf("%d\n",ans/2);
scanf("%d%d",&n,&k);
}
return 0;
}
当然,我们还有一种比较暴力的做法(虽然时间复杂度一样),就是信息学竞赛中很常用的treap+启发式合并。我们对于每一个节点都建立一个treap,存储以它为根的子树的所有节点的深度dep[i],我们还要查看以当前节点为lca的长度<=k的(u,v)路径有多少条。我们在得到每一个当前节点node的儿子son的时候,先不要急着把son的treap加进node的treap中,我们先统计答案。假设后者的size大于前者,我们就遍历一遍son的treap,取出每一个dep[u](u在son的子树中),然后查找当前node的treap中有多少个dep[v]<=k+2*dep[node]-dep[u],由于u,v来自node的两棵不同的子树,所以不会重复统计答案。做完一个son之后,把它的treap合并,再做下一个son。时间复杂度O(n*log^2(n))。
CODE:
#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
#include<ctime>
using namespace std;
const int maxn=10010;
const int maxl=18;
struct Tnode
{
Tnode *lson,*rson;
int val,fix;
int _Size;
} tree[maxn*maxl];
Tnode *root[maxn];
int cur;
struct data
{
int obj,len,_Next;
} e[maxn<<1];
int head[maxn];
int Ecur;
int dep[maxn];
int ans;
int n,k;
void Add(int x,int y,int v)
{
Ecur++;
e[Ecur].obj=y;
e[Ecur].len=v;
e[Ecur]._Next=head[x];
head[x]=Ecur;
}
Tnode *New_node(int v)
{
cur++;
tree[cur].lson=tree[cur].rson=NULL;
tree[cur].val=v;
tree[cur].fix=rand();
tree[cur]._Size=1;
return tree+cur;
}
void Recount(Tnode *&P)
{
int temp=1;
if (P->lson) temp+=(P->lson->_Size);
if (P->rson) temp+=(P->rson->_Size);
P->_Size=temp;
}
void Right_turn(Tnode *&P)
{
Tnode *W=P->lson;
P->lson=W->rson;
W->rson=P;
P=W;
Recount(P->rson);
Recount(P);
}
void Left_turn(Tnode *&P)
{
Tnode *W=P->rson;
P->rson=W->lson;
W->lson=P;
P=W;
Recount(P->lson);
Recount(P);
}
void Insert(Tnode *&P,int v)
{
if (!P) P=New_node(v);
else
{
if ( v <= P->val )
{
Insert(P->lson,v);
if ( P->lson->fix < P->fix ) Right_turn(P);
}
else
{
Insert(P->rson,v);
if ( P->rson->fix < P->fix ) Left_turn(P);
}
Recount(P);
}
}
void Calc(Tnode *&P,int v)
{
if (!P) return;
if (P->val<=v)
{
ans++;
if (P->lson) ans+=(P->lson->_Size);
Calc(P->rson,v);
}
else Calc(P->lson,v);
}
void Work(Tnode *&Root,Tnode *&P,int v)
{
Calc(Root,v-(P->val));
if (P->lson) Work(Root,P->lson,v);
if (P->rson) Work(Root,P->rson,v);
}
void Cut(Tnode *&Root,Tnode *&P)
{
Insert(Root,P->val);
if (P->lson) Cut(Root,P->lson);
if (P->rson) Cut(Root,P->rson);
}
void Update(Tnode *&x,Tnode *&y,int lcadep)
{
Work(x,y,k+2*lcadep);//统计答案
Cut(x,y);//粘贴y的子树到x的子树中
}
void Dfs(int node,int fa)
{
root[node]=NULL;
Insert(root[node],dep[node]);//初始化以当前节点的treap
int p=head[node];
while (p!=-1)
{
int son=e[p].obj;
if (son!=fa)
{
dep[son]=dep[node]+e[p].len;
Dfs(son,node);
if ( root[node]->_Size > root[son]->_Size )
Update(root[node],root[son],dep[node]);
else
{
Update(root[son],root[node],dep[node]);
root[node]=root[son];
}//启发式合并
}
p=e[p]._Next;
}
}
int main()
{
freopen("c.in","r",stdin);
freopen("c.out","w",stdout);
srand(time(0));
scanf("%d%d",&n,&k);
while ( n!=0 || k!=0 )
{
cur=-1;
Ecur=-1;
ans=0;
for (int i=1; i<=n; i++) head[i]=-1;
for (int i=1; i<n; i++)
{
int a,b,v;
scanf("%d%d%d",&a,&b,&v);
Add(a,b,v);
Add(b,a,v);
}
dep[1]=0;
Dfs(1,1);
printf("%d\n",ans);
scanf("%d%d",&n,&k);
}
return 0;
}