Neko and tree
Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 524288/524288 K (Java/Others)
Total Submission(s): 0 Accepted Submission(s): 0
Problem Description
Neko has a tree with n nodes.
There are m key nodes on the tree. Neko wants you to select some key nodes such that the distance between any two selected nodes is less than or equal to k.
Neko thinks this is too easy, so Neko wants to know how many different ways there are to select the nodes.
Output the answer modulo $10^9+7$.
Note that you have to select at least one key node.
Input
The first line contains three integers n,m,k(1≤n,m,k≤5000).
Each of the next n−1 lines contains two integers u and v, indicating that there is an edge connecting node u and node v.
The last line contains m integers, the indices of the key nodes.
Output
Output the number of valid ways to select the nodes.
Sample Input
4 3 2
1 2
1 3
2 4
2 3 1
Sample Output
7
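用了一种非常暴力的方法:首先预处理树上两点之间的距离,然后枚举关键节点。对于关键节点 x,设排在 x 之后且与 x 距离小于等于 k 的关键节点个数为 cnt,那么当前这个关键节点对答案的贡献就是 $2^{cnt}$,所以 $O(m^2)$ 暴力统计就行了。(注意:这个做法只保证每个被选节点与 x 的距离不超过 k,并没有检查其余被选节点两两之间的距离,所以一般情况下是错误的。例如路径 1-2-3,关键节点按 2 1 3 给出且 k=1 时,会把不合法的 {1,2,3} 也计入,得到 6 而不是 5。)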
(刚开始超时。。侥幸卡过去了)
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using ull = unsigned long long;
// Capacity of every node-indexed array below (valid indices 0..maxn-1).
const int maxn = 10000;
// Modulus applied to the answer: 1e9+7.
const int mod = 1000000007;
// G[u]: adjacency list of node u in the input tree.
std::vector<int> G[maxn];
// dep: node depths; n, m, k: input sizes; Fa: binary-lifting ancestor table;
// a: key node indices in input order.
int dep[maxn], n, m, k, Fa[maxn][15], a[maxn];
// Accumulated answer (zero-initialised as a global).
ll ans;
/*
 * Reads the next integer from stdin (via getchar) into X.
 *
 * Every non-digit character before the number is skipped; if any skipped
 * character was '-', the result is negated (same rule as the original fast
 * reader). The first non-digit after the number is consumed.
 *
 * If input ends before a digit is found, X is set to 0 and the function
 * returns. The previous version looped forever there, because a `char` cannot
 * represent EOF distinctly.
 */
inline void read(int &X)
{
    X = 0;
    bool negative = false;
    // Keep the character as an int: it must be able to hold EOF, and
    // isdigit() is only defined for EOF or values representable as
    // unsigned char.
    int ch = getchar();
    while (ch != EOF && !isdigit(ch))
    {
        negative |= (ch == '-');
        ch = getchar();
    }
    while (ch != EOF && isdigit(ch))
    {
        X = X * 10 + (ch - '0');
        ch = getchar();
    }
    if (negative) X = -X;
}
// Computes a^b modulo `mod` by square-and-multiply over the bits of b.
// NOTE(review): intended for b >= 0; a negative b never reaches zero under >>=.
ll fpow(ll a,ll b)
{
    ll result = 1;
    for (ll base = a, e = b; e != 0; e >>= 1)
    {
        if (e & 1) result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}
// Recursively fills depth and binary-lifting ancestors for the subtree of u.
// dep[u] = dep[fa] + 1, so calling dfs(root, 0) gives the root depth 1 and
// makes node 0 a virtual parent above the root.
// Fa[u][i] holds the 2^i-th ancestor of u, filled only while 2^i <= dep[u];
// ancestors are processed first (pre-order), so their rows are ready.
inline void dfs(int u,int fa)
{
    Fa[u][0] = fa;
    dep[u] = dep[fa] + 1;
    for (int lvl = 1; (1 << lvl) <= dep[u]; ++lvl)
    {
        const int halfway = Fa[u][lvl - 1];
        Fa[u][lvl] = Fa[halfway][lvl - 1];
    }
    for (size_t idx = 0; idx < G[u].size(); ++idx)
    {
        const int child = G[u][idx];
        if (child != fa) dfs(child, u);
    }
}
inline int lca(int x,int y)
{
if(dep[x]<dep[y])swap(x,y);
for(int i=14;i>=0;i--)if((1<<i)<=dep[x]-dep[y])x=Fa[x][i];
if(x==y)return x;
for(int i=14;i>=0;i--)if(Fa[x][i]!=Fa[y][i])x=Fa[x][i],y=Fa[y][i];
return Fa[x][0];
}
// Number of edges on the tree path between x and y (depths measured to their LCA).
inline int dis(int x,int y)
{
    const int meet = lca(x, y);
    return (dep[x] - dep[meet]) + (dep[y] - dep[meet]);
}
int main()
{
read(n);read(m);read(k);
for(int i=1;i<n;++i)
{
int u,v;read(u);read(v);
G[u].emplace_back(v);G[v].emplace_back(u);
}
dfs(1,0);
for(int i=0;i<m;++i)read(a[i]);
for(int i=0;i<m;++i)
{
ll cnt=0;
for(int j=i+1;j<m;++j)
if(dis(a[i],a[j])<=k) cnt++;
ans+=(fpow(2,cnt))%mod;
}
printf("%lld\n",ans%mod);
return 0;
}