题目大意:给定n个节点组成的树,树有边权,现在给定一个点u和v,dis(u,v)表示u和v节点的最近距离,问dis(u,v)<=k的uv对数,n <= 1万,k不定,权值<=1000。
解题思路:一看题目就觉得非常经典,不会做。这题是楼教主的男人八题之一,考察的是树上的分治和树形dp的思想。我想到的是一个n^2的暴力算法和一个n*k的背包解法,但是都是TLE。无奈之下只好去搜解题报告,然后就发现男人八题、树上分治等等高级词汇都出现,然后就确定这题十分经典。现在给出我的想法:
1、指定1为根将树变成有根树,那两个点的最近距离就有两种情况。其一,它们在同一个分支上,那么最近距离就是他们间的距离。其二,它们不在同一个分支上,那么他们的最短距离就是他们到最近公共祖先的距离和。每个根都可以算出相应的对数,并且根与其它根的计算并不冲突,这是一条重要的性质。
2、有了上面的分析基础接下来的就好办,先找到以某点为根的子孙节点到根的距离,然后从这些距离里面找出两个距离之和小等于k的方案数。这样似乎就可以了,但是有问题,如果他们在同一个分支,那么就会重复计算,比如1是根,1<-2<-3<-4,dis(4,1) = 10,dis(3,1) = 5,k = 20,那么dis(4,1)和dis(3,1)也合法,这显然不符合逻辑。我们应该减去这一在同一分支中的部分,具体解法是用1算出来的方案数减去根为2算出来的方案数也即要减去以子节点为根算出来的方案数,因为这部分方案里两个点到根的距离相加小等于k但不是我们要计算的当前根的方案数。
3、把距离都算出来之后,要怎么快速找到方案数呢?两个for循环就可能超时,想着O(n)解决,先对距离序列排序,然后找头尾两个数,如果符合情况,算中间的个数,然后从头的下一个开始算,如果不符合情况说明太大,要从尾的前一个开始计算,知道头等于尾.这样计算的复杂度是O(n),排序O(nlogn)。
4、这样仍然可能超时,如果这棵树是一条链,那么它要算n层,复杂度为n^2logn。怎么降低复杂度呢?因为每次以某点为根的计算都不影响其他点计算,那么我们每次都找重心,即可变成logn层!这篇论文讲得比我讲的详细Here
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std;
#define MAX 21000
#define INF 2147483647
#define max(a,b) (a)>(b)?(a):(b)
#define min(a,b) (a)<(b)?(a):(b)
struct node {
int v, len; //v表示邻接点,len表示边权
int sum, bal; //sum表示子孙节点个数,bal表示拆当前点后子树的最大结点数
node *next;
} tree[MAX], *head[MAX], tp[MAX];
int n, ptr, ans, root;
int tot, k, dist[MAX], vis[MAX];
int size[MAX], sign[MAX]; //size表示最大分支的结点数,sign是一个hash数组
void Initial() {
ans = ptr = 0;
for (int i = 0; i < MAX; ++i)
vis[i] = 0,head[i] = NULL;
}
void AddEdge(int a, int b, int c) {
tree[ptr].v = b, tree[ptr].len = c;
tree[ptr].next = head[a], head[a] = &tree[ptr++];
}
void Dfs(int s, int pa) {
tp[s].sum = tp[s].bal = 0;
node *p = head[s];
while (p != NULL) {
if (p->v != pa && vis[p->v] == 0) {
Dfs(p->v, s);
tp[s].sum += tp[p->v].sum; //累计子节点个数
tp[s].bal = max(tp[s].bal, tp[p->v].sum); //找最大分支
}
p = p->next;
}
tp[s].sum++; //自己
sign[tot] = s; //hash
size[tot++] = tp[s].bal; //记录每个最大分支的结点数
}
int GetRoot(int s) {
tot = 0, Dfs(s, 0);
int maxx = INF, maxi, cnt = tp[s].sum;
for (int i = 0; i < tot; ++i) {
size[i] = max(size[i], cnt - size[i]);
if (size[i] < maxx) {
maxx = size[i];
maxi = sign[i];
}
}
return maxi;
}
void GetDist(int s, int pa, int dis) {
//保存每个结点到根节点的距离
node *p = head[s];
dist[tot++] = dis;
while (p != NULL) {
if (p->v != pa && vis[p->v] == 0 && dis + p->len <= k)
GetDist(p->v, s, dis + p->len);
p = p->next;
}
}
void Count1(int s) {
sort(dist, dist + tot);
int left = 0, right = tot - 1;
while (left < right) {
if (dist[left] + dist[right] <= k)
ans += right - left, left++;
else right--;
}
}
void Count2(int s) {
vis[s] = 1;
node *p = head[s];
while (p != NULL) {
if (vis[p->v] == 0) {
tot = 0, GetDist(p->v, s, p->len);
sort(dist, dist + tot);
int left = 0, right = tot - 1;
while (left < right) {
if (dist[left] + dist[right] <= k)
ans -= right - left, left++;
else right--;
}
}
p = p->next;
}
}
void Solve(int s, int pa) {
root = GetRoot(s);
tot = 0,GetDist(root, 0, 0);
Count1(root);
Count2(root); // ans += count1 - coutn2;
node *p = head[root];
while (p != NULL) {
if (p->v != pa && vis[p->v] == 0)
Solve(p->v, root);
p = p->next;
}
}
int main()
{
int i, j, a, b, c;
while (scanf("%d%d", &n, &k),n + k) {
Initial();
for (i = 1; i < n; ++i) {
scanf("%d%d%d", &a, &b, &c);
AddEdge(a, b, c);
AddEdge(b, a, c);
}
Solve(1, 0);
printf("%d\n", ans);
}
}
本文ZeroClock原创,但可以转载,因为我们是兄弟。