题意:
给你一棵
n
n
n个点带边权的树,有
m
m
m次询问,每次问你树上是否存在一条长度为
x
x
x的路径。
n
<
=
10000
,
m
<
=
100
,
x
<
=
1
e
6
n<=10000,m<=100,x<=1e6
n<=10000,m<=100,x<=1e6.
题解:
树上路径问题还是考虑用点分治。这个题让我发现我洛谷模板的点分治复杂度写假了,我模板的点分治复杂度是
n
2
l
o
g
n
n^2logn
n2logn的,却过了
10000
10000
10000,不知道是怎么造的数据。。。
之前的写法每一层会 n 2 n^2 n2合并,但是这个题我们会发现其实询问次数非常少。我们可以对于每一个询问点分治一次,但是这样常数会有点大,于是我们的办法是做一遍点分治,并在过程中处理所有询问。具体做法是,我们每次dfs每一棵子树,搜出当前子树所有出现过的链的长度,然后枚举每一个询问,看能不能与之前子树出现过的链长合并来拼出询问要的长度。然后去更新已经考虑过的子树信息,把这棵子树出现过的长度全都加进去,记录一下所有出现过的长度以便快速删除,这样保证每一层的复杂度是与当前层子树大小有关的,而不是与 n n n有关的。这样就做完了。
复杂度是 O ( n ∗ m ∗ l o g n ) O(n*m*logn) O(n∗m∗logn)。这个复杂度就比较真了,在洛谷上跑得也很快了。
代码:
#include <bits/stdc++.h>
using namespace std;
int n,p,ans[10010],hed[100010],cnt,vis[100010];
int sz[10010],mx[10010],rt,shu,num;
long long q[1000010],book[1000010],jilu[1000010],ji[1000010];
struct node
{
int to,next;
long long dis;
}a[200010];
inline long long read()
{
long long x=0;
char s=getchar();
while(s>'9'||s<'0')
s=getchar();
while(s>='0'&&s<='9')
{
x=x*10+s-'0';
s=getchar();
}
return x;
}
inline void add(int from,int to,long long dis)
{
a[++cnt].to=to;
a[cnt].dis=dis;
a[cnt].next=hed[from];
hed[from]=cnt;
}
inline void getrt(int x,int f,int size)
{
sz[x]=1;
mx[x]=0;
for(int i=hed[x];i;i=a[i].next)
{
int y=a[i].to;
if(vis[y]||y==f)
continue;
getrt(y,x,size);
sz[x]+=sz[y];
mx[x]=max(mx[x],sz[y]);
}
mx[x]=max(mx[x],size-sz[x]);
if(mx[x]<mx[rt])
rt=x;
}
inline void dfs(int x,int f,long long dis)
{
if(dis>1e6)
return;
book[++shu]=dis;
for(int i=hed[x];i;i=a[i].next)
{
int y=a[i].to;
if(vis[y]||y==f)
continue;
dfs(y,x,dis+a[i].dis);
}
}
inline void solve(int x,int size)
{
vis[x]=1;
jilu[0]=1;
num=0;
for(int i=hed[x];i;i=a[i].next)
{
int y=a[i].to;
if(vis[y])
continue;
shu=0;
dfs(y,0,a[i].dis);
for(int j=1;j<=shu;++j)
{
for(int k=1;k<=p;++k)
{
if(q[k]>=book[j])
ans[k]|=jilu[q[k]-book[j]];
}
}
for(int j=1;j<=shu;++j)
{
jilu[book[j]]=1;
ji[++num]=book[j];
}
}
for(int i=1;i<=num;++i)
jilu[ji[i]]=0;
for(int i=hed[x];i;i=a[i].next)
{
int y=a[i].to;
if(vis[y])
continue;
int gg=sz[y];
if(sz[y]>sz[x])
gg=size-sz[x];
rt=0;
getrt(y,0,gg);
solve(rt,gg);
}
}
int main()
{
n=read();
p=read();
for(int i=1;i<=n-1;++i)
{
int x=read(),y=read();
long long z=read();
add(x,y,z);
add(y,x,z);
}
mx[0]=2e9;
getrt(1,0,n);
for(int i=1;i<=p;++i)
q[i]=read();
solve(1,n);
for(int i=1;i<=p;++i)
{
if(ans[i]||!q[i])
printf("Yes\n");
else
printf("No\n");
}
return 0;
}