题目:
题目链接:https://www.luogu.com.cn/problem/P4178
给你一棵树,以及这棵树上边的距离,问有多少对点它们两者间的距离小于等于
k
k
k。
思路:
这道题是应该加强数据了。。。
当然如果这道题不打算做点分治模板题的话可以不用
O
(
n
n
log
n
)
O(n\sqrt n\ \log \sqrt n)
O(nn logn)分块在洛谷优秀的
O
2
O2
O2下过了。。。
我们假设
1
1
1为树根,
d
f
s
dfs
dfs遍历一遍就可以得到每一个节点
x
x
x到
1
1
1的距离
d
i
s
[
x
]
dis[x]
dis[x]。
然后可以暴力判断
d
i
s
dis
dis中有多少个数时小于等于
k
k
k的。
每一个点作根跑一边,时间复杂度
O
(
n
2
)
O(n^2)
O(n2)。
考虑优化上述算法。
考虑换根。我们把根从
1
1
1转移到
x
x
x时,所有
x
x
x子树内的节点到
x
x
x的距离减少了
d
i
s
(
1
,
x
)
dis(1,x)
dis(1,x),其他点到
x
x
x距离增加了
d
i
s
(
1
,
x
)
dis(1,x)
dis(1,x)。
用
d
f
s
dfs
dfs序把每一棵子树的编号变为连续的,那么我们只要在原
d
i
s
dis
dis序列中进行区间加减操作就可以了维护出以
x
x
x为根时,每一个点到
x
x
x的距离。
那么我们现在需要求整个
d
i
s
dis
dis数组中有多少个是
≤
k
\leq k
≤k的,那么其实就是 教主的魔法 那道题了。
采用分块,区间加减容易实现,对于每一个块维护一个
S
o
r
t
Sort
Sort数组,表示这个块的
d
i
s
dis
dis排序后的数组。那么每次修改时,
S
o
r
t
Sort
Sort数组可以在
O
(
n
log
n
)
O(\sqrt n\ \log \sqrt n)
O(n logn)的时间复杂度内暴力维护。
注意如果要把一个块整体加减,那么直接在这个块的
a
d
d
add
add数组中加减即可,在
d
i
s
dis
dis和
S
o
r
t
Sort
Sort中都不需要修改。
询问时枚举每一个块
p
p
p,在这个块的
S
o
r
t
Sort
Sort中二分出第一个严格大于
k
−
a
d
d
[
p
]
k-add[p]
k−add[p]的数的位置
p
o
s
pos
pos,那么这个块中小于等于
k
k
k的数字就有
p
o
s
−
1
pos-1
pos−1个。
这样就可以在
O
(
n
log
n
)
O(\sqrt n\ \log \sqrt n)
O(n logn)的时间复杂度内求出距离一个点有多少个点路径长度
≤
k
\leq k
≤k。总时间复杂度
O
(
n
n
log
n
)
O(n\sqrt n\ \log \sqrt n)
O(nn logn)。常数极大。
代码:
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N=40010,M=210;
int L[M],R[M],pos[N],add[M],Sort[M][M],head[N],dis[N],Dis[N],rk[N],dfn[N],size[N];
int n,m,T,tot,ans;
struct edge
{
int next,to,dis;
}e[N*2];
void addedge(int from,int to,int d)
{
e[++tot].to=to;
e[tot].dis=d;
e[tot].next=head[from];
head[from]=tot;
}
void dfs1(int x,int fa) //求出每一个点到点1的距离,同时求出dfs序
{
dfn[x]=++tot; rk[tot]=x; size[x]=1;
for (register int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa)
{
Dis[v]=Dis[x]+e[i].dis;
dfs1(v,x);
size[x]+=size[v];
}
}
}
void update(int l,int r,int val) //分块区间加模板
{
int p=pos[l],q=pos[r];
if (q-p<=1)
{
for (register int i=l;i<=r;i++) dis[i]+=val;
for (register int i=L[p];i<=R[q];i++) Sort[pos[i]][i-L[pos[i]]+1]=dis[i]; //整个块暴力修改
sort(Sort[p]+1,Sort[p]+1+R[p]-L[p]+1);
if (p!=q) sort(Sort[q]+1,Sort[q]+1+R[q]-L[q]+1); //重新维护
return;
}
for (register int i=l;i<=R[p];i++) dis[i]+=val;
for (register int i=L[q];i<=r;i++) dis[i]+=val;
for (register int i=L[p];i<=R[p];i++) Sort[p][i-L[p]+1]=dis[i];
for (register int i=L[q];i<=R[q];i++) Sort[q][i-L[q]+1]=dis[i]; //两边暴力修
sort(Sort[p]+1,Sort[p]+1+R[p]-L[p]+1);
sort(Sort[q]+1,Sort[q]+1+R[q]-L[q]+1);
for (register int i=p+1;i<q;i++)
add[i]+=val;
}
void dfs2(int x,int fa) //换根求答案
{
for (register int i=1;i<=T;i++)
ans+=upper_bound(Sort[i]+1,Sort[i]+1+R[i]-L[i]+1,m-add[i])-Sort[i]-1;
for (register int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa)
{
update(dfn[v],dfn[v]+size[v]-1,-e[i].dis*2);
update(1,n,e[i].dis); //换根
dfs2(v,x);
update(dfn[v],dfn[v]+size[v]-1,e[i].dis*2);
update(1,n,-e[i].dis);
}
}
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d",&n);
for (register int i=1,x,y,z;i<n;i++)
{
scanf("%d%d%d",&x,&y,&z);
addedge(x,y,z); addedge(y,x,z);
}
scanf("%d",&m);
tot=0; T=sqrt(n);
if (T*T<n) T++;
dfs1(1,0);
for (register int i=1;i<=T;i++)
{
L[i]=R[i-1]+1; R[i]=min(n,T*i);
for (register int j=L[i];j<=R[i];j++)
dis[j]=Sort[i][j-L[i]+1]=Dis[rk[j]],pos[j]=i;
sort(Sort[i]+1,Sort[i]+1+R[i]-L[i]+1);
}
dfs2(1,0);
printf("%d",(ans-n)/2); //先减去n个(x,x)的点对,然后(x,y)和(y,x)只算一个,所以除以2
return 0;
}