poi 1741
题目:http://poj.org/problem?id=1741
题意:给你一棵最多 10^4 个点组成的树,每根树枝的长度最多为 10^3 ,问你两个点之间的距离<=k 的点对数。
思路:
楼教男人八题之一。。 显然,O(N^2) 找点对的方法是不行的,而 O(NK) (k<=10^9)的动态规划也是不行的。好吧,具体思路参见漆子超的论文:http://wenku.baidu.com/view/e087065f804d2b160b4ec0b5.html###,很好的一篇论文,在树上用分治来做。在减少递归的层次上,用了一个技巧,就是每次都找这棵树的重心,其在最差情况下的层数为 O(logN),如果不这样做,那么最差情况,即一根链的时候,层数高达O(N),会TLE。算出每个点到根节点的距离后,匹配找出 Depth[i]+Depth[j]<=K 时,先将 dis 数组排序,然后从两边开始找,这样找的时间复杂度为O(N),而排序是 O(logN),所以总的是O(NlogN)。综上,该算法总的时间复杂度为 Nlog(N)*log(N)。
看了论文以后,第一遍做,过了样例,一交,还是TLE了,一看别人AC的,因为是双向边,我用了一个 vis 数组,这样每次都要清零,耗费了很多时间,其实只要记录它的father ,不要走回去就好。改过来后,交了,WA了。这次检查对照了好久,直到找到下面 6 6 这组数据时才发现,原来是在找 Depth[i]+Depth[j]<=K且Belong[i]=Belong[j]数对(i,j)的个数时候错了。因为我每次都是dfs2(u,0,0),再算del的时候,应该是的dfs2(u,0,len),连接u、v的那条边要算进去,作为dis加上去的初始值。只能说,最后AC的时候好开心,又向男人进了一步。。 = =
代码如下:
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
const int INF = 0x0fffffff ;
const int MAXN = 11111 ;
int n,m;
struct Edge
{
int t,next,len;
} edge[MAXN<<1];
int head[MAXN],tot;
void add_edge(int s,int t,int len)
{
edge[tot].t=t;
edge[tot].len=len;
edge[tot].next = head[s];
head[s] = tot++;
}
int root;
int getted[MAXN];
vector <int> node;
int num[MAXN],maxv[MAXN];
void dfs1(int u,int fa)
{
node.push_back(u);
num[u]=1;
maxv[u]=0;
for(int e = head[u] ;e!=-1;e=edge[e].next)
{
int v = edge[e].t;
if(getted[v]||v==fa) continue;
dfs1(v,u);
num[u]+=num[v];
maxv[u]=max(maxv[u],num[v]);
}
}
void get_root(int x)
{
node.clear();
dfs1(x,0);
int minn=INF;
int sum_node = num[x];
for(int i=0;i<node.size();i++)
{
int cur = node[i];
maxv[cur] = max(maxv[cur],sum_node-num[cur]);
if(maxv[cur]<minn)
{
minn = maxv[cur];
root = cur;
}
}
}
vector <int> dis;
void dfs2(int u,int fa,int s)
{
dis.push_back(s);
for(int e = head[u];e!=-1;e =edge[e].next)
{
int v = edge[e].t;
int len = edge[e].len;
if(getted[v]||v==fa||s+len>m) continue;
dfs2(v,u,s+len);
}
}
void get_dis(int u,int dist)
{
dis.clear();
dfs2(u,0,dist);
}
int ans;
void count_add()
{
get_dis(root,0);
sort(dis.begin(),dis.end());
int j=dis.size()-1;
for(int i = 0;i<j;)
{
if(dis[i]+dis[j]<=m)
{
ans+=j-i;
i++;
}
else
{
j--;
}
}
/*printf("dis+\n");
for(int i=0;i<dis.size();i++)
printf("%d ",dis[i]);
puts("");*/
}
void count_del()
{
for(int e = head[root] ; e!=-1;e=edge[e].next)
{
int v = edge[e].t;
int len = edge[e].len;
if(getted[v]) continue;
get_dis(v,len);
sort(dis.begin(),dis.end());
/*
puts("dis-");
for(int i=0;i<dis.size();i++)
printf("%d ",dis[i]);
puts("");*/
int j = dis.size()-1;
for(int i=0;i<j;)
{
if(dis[i]+dis[j]<=m)
{
ans-=(j-i);
i++;
}
else j--;
}
}
}
void solve(int x)
{
get_root(x);
getted[root] = 1;
//printf("root = %d\n",root);
count_add();
//printf("ans1 = %d\n",ans);
count_del();
//printf("ans2 = %d\n",ans);
for(int e = head[root] ; e != -1; e =edge[e].next)
{
int v = edge[e].t;
if(getted[v]) continue;
solve(v);
}
}
int main()
{
while(~scanf("%d%d",&n,&m))
{
if(n+m==0) break;
memset(head,-1,sizeof(head));
tot=0;
int a,b,c;
for(int i=1;i<n;i++)
{
scanf("%d%d%d",&a,&b,&c);
add_edge(a,b,c);
add_edge(b,a,c);
}
memset(getted,0,sizeof(getted));
ans=0;
solve(1);
printf("%d\n",ans);
}
return 0;
}
/*
5 4
1 2 3
1 3 1
1 4 2
3 5 1
6 6
1 2 3
1 3 1
1 4 2
3 5 1
5 6 1
0 0
*/