P3806 【模板】点分治1
题目背景
感谢hzwer的点分治互测。
题目描述
给定一棵有n个点的树
询问树上距离为k的点对是否存在。
输入格式
n,m 接下来n-1条边a,b,c描述a到b有一条长度为c的路径
接下来m行每行询问一个K
输出格式
对于每个K每行输出一个答案,存在输出“AYE”,否则输出”NAY”(不包含引号)
输入输出样例
输入 #1复制
2 1 1 2 2 2
输出 #1复制
AYE
说明/提示
对于30%的数据n<=100
对于60%的数据n<=1000,m<=50
对于100%的数据n<=10000,m<=100,c<=10000,K<=10000000
直接按照POJ 1742的代码改的,其他的没什么。
代码:
1 //树分治-点分治 2 #include<bits/stdc++.h> 3 using namespace std; 4 const int inf=1e9+10; 5 const int maxn=1e4+10; 6 const int maxm=1e7+10; 7 8 int head[maxn<<1],tot; 9 int root,allnode,n,m,k; 10 int vis[maxn],deep[maxn],dis[maxn],siz[maxn],maxv[maxn];//deep[0]子节点个数(路径长度),maxv为重心节点 11 int ans[maxm]; 12 13 struct node{ 14 int to,next,val; 15 }edge[maxn<<1]; 16 17 void add(int u,int v,int w)//前向星存图 18 { 19 edge[tot].to=v; 20 edge[tot].next=head[u]; 21 edge[tot].val=w; 22 head[u]=tot++; 23 } 24 25 void init()//初始化 26 { 27 memset(head,-1,sizeof head); 28 memset(vis,0,sizeof vis); 29 tot=0; 30 } 31 32 void get_root(int u,int father)//重心 33 { 34 siz[u]=1;maxv[u]=0; 35 for(int i=head[u];~i;i=edge[i].next){ 36 int v=edge[i].to; 37 if(v==father||vis[v]) continue; 38 get_root(v,u);//递归得到子树大小 39 siz[u]+=siz[v]; 40 maxv[u]=max(maxv[u],siz[v]);//更新u节点的maxv 41 } 42 maxv[u]=max(maxv[u],allnode-siz[u]);//保存节点size 43 if(maxv[u]<maxv[root]) root=u;//更新当前子树的重心 44 } 45 46 void get_dis(int u,int father)//获取子树所有节点与根的距离 47 { 48 deep[++deep[0]]=dis[u]; 49 for(int i=head[u];~i;i=edge[i].next){ 50 int v=edge[i].to; 51 if(v==father||vis[v]) continue; 52 int w=edge[i].val; 53 dis[v]=dis[u]+w; 54 get_dis(v,u); 55 } 56 } 57 58 void cal(int u,int now,int val) 59 { 60 dis[u]=now;deep[0]=0; 61 get_dis(u,0); 62 sort(deep+1,deep+deep[0]+1); 63 for(int i=1;i<=deep[0];i++){ 64 for(int j=1;j<=deep[0];j++){ 65 if(i!=j) ans[deep[i]+deep[j]]+=val; 66 } 67 } 68 } 69 70 void solve(int u) 71 { 72 vis[u]=1; 73 cal(u,0,1); 74 vis[u]=1; 75 for(int i=head[u];~i;i=edge[i].next){ 76 int v=edge[i].to; 77 int w=edge[i].val; 78 if(vis[v]) continue; 79 cal(v,w,-1); 80 allnode=siz[v]; 81 root=0; 82 get_root(v,u); 83 solve(v); 84 } 85 } 86 87 int main() 88 { 89 scanf("%d%d",&n,&m); 90 init(); 91 for(int i=1;i<n;i++){ 92 int u,v,w; 93 scanf("%d%d%d",&u,&v,&w); 94 add(u,v,w); 95 add(v,u,w); 96 } 97 root=0;allnode=n;maxv[0]=inf; 98 get_root(1,0); 99 solve(root); 100 while(m--){ 101 scanf("%d",&k); 102 if(ans[k]){ 103 printf("AYE\n"); 104 } 105 else{ 106 printf("NAY\n"); 107 } 108 } 109 return 0; 110 }