树
时间限制: 1000 M S 1000MS 1000MS 内存限制: 30 M B 30MB 30MB
题目描述
给一个有
n
n
n个顶点的树,每个边都有一个长度(正整数小于
1001
1001
1001)。
定义
d
i
s
t
(
u
,
v
)
=
dist(u,v)=
dist(u,v)=节点
u
u
u和
v
v
v之间的最短距离。
给出一个整数
k
k
k,当且仅当
d
i
s
t
(
u
,
v
)
dist(u,v)
dist(u,v)不超过
k
k
k时,每个
(
u
,
v
)
(u,v)
(u,v)的点对被称为有效。
编写一个程序,计算对给定树的有效点对对数。
输入格式
输入包含几个测试用例。每个测试用例的第一行包含两个整数
n
,
k
n,k
n,k。
(
n
≤
10000
)
(n ≤ 10000)
(n≤10000)以下
n
−
1
n-1
n−1行各包含三个整数
u
,
v
,
l
u,v,l
u,v,l,这意味着在节点
u
u
u和
v
v
v之间存在长度为
l
l
l的边。
最后一个测试用例后跟两个零。
输出格式
对于每个测试用例,在单行上输出答案。
样例输入
5
5
5
4
4
4
1
1
1
2
2
2
3
3
3
1
1
1
3
3
3
1
1
1
1
1
1
4
4
4
2
2
2
3
3
3
5
5
5
1
1
1
0
0
0
0
0
0
样例输出
8 8 8
点分治是用来解决树上路径问题的一个强有力的工具。
如果指定根节点
r
t
rt
rt,那么对于树上的每一条路径都可以分成两种:
- 经过根节点 r t rt rt
- 被包含于根节点
r
t
rt
rt的某棵子树中
对于第二种路径,我们显然可以将其作为原问题的子问题递归求解。
对于第一种路径,我们又可以将其分成两类:
- 路径的一端是 r t rt rt
- 路径的两端都在 r t rt rt的某棵子树中
那么,我们可以将第二类路径拆分成
u
u
u ~
r
t
rt
rt和
r
t
rt
rt ~
v
v
v的两段。
如何统计答案呢?
我们可以先预处理出每个子孙节点到
r
t
rt
rt的距离,
然后在不同的子树中查询距离和
≤
k
\leq k
≤k的点对。
一种方法是用树状数组计数。
对于每棵子树中的每个节点
u
u
u,都去它之前的子树中查找
k
−
d
[
u
]
k-d[u]
k−d[u]的数量。
但遗憾的是计数范围达到了
1
0
7
10^7
107级别,显然树状数组无法支持(而且这很难离散化)
于是随便拿棵平衡树乱维护一下吧
关于根节点的选取,我们可以选取重心,这样,递归层数就不会超过
O
(
l
g
n
)
O(lgn)
O(lgn)层
每层复杂度
O
(
n
)
O(n)
O(n)
所以点分治复杂度
O
(
n
l
o
g
n
)
O(n log n)
O(nlogn)
加上
T
r
e
a
p
Treap
Treap的时间复杂度
总时间复杂度
O
(
n
l
o
g
2
n
)
O(n log^2 n)
O(nlog2n)
代码
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
typedef long long ll;
const int maxn = 10005;
const int maxe = 10005;
int edgenum;
int Next[maxe << 1] , vet[maxe << 1] , val[maxe << 1] , head[maxn];
int rt , tot_siz;
int siz[maxn] , son[maxn];
bool vis[maxn];
vector < int > dis;
int d[maxn];
int k;
ll ans;
int max(int x , int y) {return x > y ? x : y;}
int read() {
char ch = getchar(); bool f = 1;
while(ch < '0' || ch > '9') f &= ch != '-' , ch = getchar();
int res = 0;
while(ch >= '0' && ch <= '9') res = (res << 3) + (res << 1) + (ch ^ 48) , ch = getchar();
return f ? res : -res;
}
class Treap {
private:
static const int maxt = 10005;
struct node {int son[2] , val , pri , cnt , siz;}t[maxt];
int cnt;
int build(int val) {
t[++cnt].val = val; t[cnt].pri = rand();
t[cnt].cnt = t[cnt].siz = 1;
t[cnt].son[0] = t[cnt].son[1] = 0;
return cnt;
}
void update(int p) {t[p].siz = t[t[p].son[0]].siz + t[p].cnt + t[t[p].son[1]].siz;}
void rotate(int &p , int d) {
int k = t[p].son[d];
t[p].son[d] = t[k].son[d ^ 1];
t[k].son[d ^ 1] = p;
update(p); update(p = k);
}
public:
int rt;
void clear() {rt = cnt = 0;}
void insert(int &p , int val) {
if(!p) {p = build(val); return;}
t[p].siz++;
if(t[p].val == val) {t[p].cnt++; return;}
int d = t[p].val < val;
insert(t[p].son[d] , val);
if(t[p].pri > t[t[p].son[d]].pri) rotate(p , d);
}
int rnk(int p , int val) {
if(!p) return 0;
if(t[p].val > val) return rnk(t[p].son[0] , val);
if(t[p].val == val) return t[t[p].son[0]].siz + t[p].cnt;
return t[t[p].son[0]].siz + t[p].cnt + rnk(t[p].son[1] , val);
}
}treap;
void clear_edge(int n) {
edgenum = 0;
for(int i = 1; i <= n; i++) head[i] = 0;
}
void add_edge(int u , int v , int cost) {
Next[++edgenum] = head[u];
vet[edgenum] = v;
val[edgenum] = cost;
head[u] = edgenum;
}
void get_rt(int u , int fa) {
siz[u] = 1; son[u] = 0;
for(int e = head[u]; e; e = Next[e]) {
int v = vet[e];
if(v == fa || vis[v]) continue;
get_rt(v , u);
siz[u] += siz[v];
son[u] = max(son[u] , siz[v]);
}
son[u] = max(son[u] , tot_siz - siz[u]);
if(son[u] < son[rt]) rt = u;
}
void dfs(int u , int fa) {
for(int e = head[u]; e; e = Next[e]) {
int v = vet[e];
if(v == fa || vis[v]) continue;
dis.push_back(d[v] = d[u] + val[e]);
dfs(v , u);
}
}
void solve(int u) {
vis[u] = 1; treap.clear();
for(int e = head[u]; e; e = Next[e]) {
int v = vet[e];
if(vis[v]) continue;
dis.clear(); dis.push_back(d[v] = val[e]);
dfs(v , 0);
int dis_siz = dis.size();
for(int i = 0; i < dis_siz; i++) ans += treap.rnk(treap.rt , k - dis[i]);
for(int i = 0; i < dis_siz; i++) treap.insert(treap.rt , dis[i]);
}
ans += treap.rnk(treap.rt , k);
for(int e = head[u]; e; e = Next[e]) {
int v = vet[e];
if(vis[v]) continue;
rt = 0; tot_siz = siz[v];
get_rt(v , 0);
solve(rt);
}
}
int main() {
int n;
for(n = read() , k = read(); n && k; n = read() , k = read()) {
clear_edge(n);
for(int i = 1; i < n; i++) {
int u = read() , v = read() , cost = read();
add_edge(u , v , cost);
add_edge(v , u , cost);
}
for(int i = 1; i <= n; i++) vis[i] = 0;
rt = 0; tot_siz = n;
son[0] = n;
get_rt(1 , 0);
ans = 0;
solve(rt);
printf("%lld\n",ans);
}
return 0;
}