As you know, an undirected connected graph with n nodes and n - 1 edges is called a tree. You are given an integer d and a tree consisting of n nodes. Each node i has a value $a_i$ associated with it.
We call a set S of tree nodes valid if the following conditions are satisfied:
- S is non-empty.
- S is connected. In other words, if nodes u and v are in S, then all nodes lying on the simple path between u and v must also be present in S.
- $\max_{v \in S} a_v - \min_{v \in S} a_v \le d$.
Your task is to count the number of valid sets. Since the result can be very large, print it modulo 1000000007 ($10^9 + 7$).
The first line contains two space-separated integers d (0 ≤ d ≤ 2000) and n (1 ≤ n ≤ 2000).
The second line contains n space-separated positive integers $a_1, a_2, \ldots, a_n$ ($1 \le a_i \le 2000$).
Each of the next n - 1 lines contains a pair of integers u and v (1 ≤ u, v ≤ n), denoting an edge between u and v. It is guaranteed that these edges form a tree.
Print the number of valid sets modulo 1000000007.
Sample input 1:
1 4
2 1 3 2
1 2
1 3
3 4
Sample output 1:
8

Sample input 2:
0 3
1 2 3
1 2
2 3
Sample output 2:
3

Sample input 3:
4 8
7 8 7 5 4 6 4 10
1 6
1 2
5 8
1 3
3 5
6 7
3 4
Sample output 3:
41
In the first sample, there are exactly 8 valid sets: {1}, {2}, {3}, {4}, {1, 2}, {1, 3}, {3, 4} and {1, 3, 4}. Set {1, 2, 3, 4} is not valid, because the third condition isn't satisfied. Set {1, 4} satisfies the third condition, but conflicts with the second condition.
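The sample explanation can also be checked mechanically. The sketch below is a brute-force reference counter; `bruteForceValidSets` is a hypothetical helper, not part of the original solution. It assumes a tiny tree (n at most 20), enumerates every non-empty subset as a bitmask, and tests the two nontrivial conditions directly. For connectivity it relies on the fact that the nodes of S induce a forest, which is connected exactly when it contains |S| - 1 of the tree's edges.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <utility>
#include <vector>

// Hypothetical brute-force reference counter (not the original solution).
// a[1..n] holds the node values; a[0] is an unused placeholder so that node
// numbers match the statement. edges lists the n - 1 tree edges (1-indexed).
long long bruteForceValidSets(int d, const std::vector<int>& a,
                              const std::vector<std::pair<int, int>>& edges) {
    const int n = static_cast<int>(a.size()) - 1;
    assert(n >= 1 && n <= 20);  // 2^n subsets: only practical for tiny trees
    long long count = 0;
    for (int mask = 1; mask < (1 << n); ++mask) {  // every non-empty subset S
        int lo = INT_MAX, hi = INT_MIN, members = 0;
        for (int i = 1; i <= n; ++i) {
            if ((mask >> (i - 1)) & 1) {
                lo = std::min(lo, a[i]);
                hi = std::max(hi, a[i]);
                ++members;
            }
        }
        if (hi - lo > d) continue;  // condition 3: max - min <= d
        // Condition 2: the nodes of S induce a forest, which is connected
        // exactly when it has |S| - 1 edges.
        int inner = 0;
        for (const auto& e : edges) {
            if (((mask >> (e.first - 1)) & 1) && ((mask >> (e.second - 1)) & 1)) {
                ++inner;
            }
        }
        if (inner == members - 1) ++count;
    }
    return count;
}
```

For the first sample, `bruteForceValidSets(1, {0, 2, 1, 3, 2}, {{1, 2}, {1, 3}, {3, 4}})` returns 8. Exponential enumeration is of course useless for n = 2000; the point is only to have an independent answer to compare a real solution against on small inputs.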
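This is a rather nice tree DP problem.
Let's start from the condition that has to be satisfied: the maximum minus the minimum must be at most d. So we fix each node in turn as the maximum of the set (call it root) and count the sets whose maximum is root. A node u can be added to S only if a[u] <= a[root] and a[u] + d >= a[root]. Root the tree at root and let dp[u] be the number of connected sets that contain u, lie inside u's subtree, and consist only of nodes satisfying this condition. For each child v of u, we either take nothing from v's subtree or take one of the dp[v] sets containing v, so the transition is dp[u] *= (dp[v] + 1). If v does not satisfy the condition, its whole subtree is skipped, because any set reaching below v would have to contain v. The answer is the sum of dp[root] over all choices of root. One detail remains: if a set contains several nodes with the maximum value, it would be counted once for each of them. To avoid this, a node v with a[v] == a[root] is admitted only if it has not been used as a root yet (the vis array in the code), so every set is counted exactly once, at its maximum-value node with the smallest index.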
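The core of the transition is a standard counting fact: a connected set containing u is formed by choosing, independently for each child v, either nothing from v's subtree or one of the connected sets that contain v. The hypothetical helper `countConnectedContaining` below isolates just this recurrence on an unrestricted tree, with no value filter, no tie-breaking and no modulus, so it is only meant for small inputs. It assumes 1-indexed adjacency lists and passes 0 as the parent of the root.

```cpp
#include <cassert>
#include <vector>

// Hypothetical illustration: number of connected node sets that contain u and
// lie inside u's subtree (the tree being rooted at the first call's u).
// Recurrence: ways(u) = product over children v of (ways(v) + 1), where the
// "+ 1" is the option of taking nothing from v's subtree.
long long countConnectedContaining(int u, int parent,
                                   const std::vector<std::vector<int>>& g) {
    assert(u >= 0 && u < static_cast<int>(g.size()));
    long long ways = 1;  // the set {u} by itself
    for (int v : g[u]) {
        if (v == parent) continue;
        ways *= countConnectedContaining(v, u, g) + 1;
    }
    return ways;
}
```

As a check against the first sample: with root = 3 (a[3] = 3, d = 1), only nodes with values in [2, 3] survive, namely nodes 1, 3 and 4, while node 2 (value 1) is pruned. On the remaining tree, node 3 has the two leaf children 1 and 4, so the recurrence gives (1 + 1) * (1 + 1) = 4 sets: {3}, {1, 3}, {3, 4} and {1, 3, 4}, which are exactly the sample's valid sets whose maximum is node 3.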
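#include <cstdio>
#include <cstring>
#include <vector>
#include <iostream>
#include <algorithm>
#define mod 1000000007
#define maxn 2005
#define LL long long
using namespace std;
vector<int> g[maxn];  // adjacency lists, nodes are 1-indexed
int a[maxn];          // node values
LL dp[maxn];          // dp[u]: connected sets inside u's subtree that contain u and use only admissible nodes
bool vis[maxn];       // vis[x] = 1 once x has been processed as the root (maximum)
int root, d;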
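// Fill dp[u] for the tree rooted at root, descending only into admissible
// children v: a[root] - d <= a[v] <= a[root], with ties on a[root] allowed only
// for nodes not yet used as root, so a set with several maxima is counted once.
void dfs(int u, int pa) {
    int k = (int)g[u].size();
    dp[u] = 1;  // the set {u} by itself
    for (int i = 0; i < k; i++) {
        int v = g[u][i];
        if (v == pa) continue;
        if ((a[v] < a[root] || (a[v] == a[root] && !vis[v])) && a[v] + d >= a[root]) {
            dfs(v, u);
            // Take nothing from v's subtree, or one of its dp[v] sets.
            // Both factors are below mod, so the product fits in long long.
            dp[u] *= (dp[v] + 1);
            dp[u] %= mod;
        }
    }
}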
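int main()
{
    int i, n, m, p;
    cin >> d >> n;
    for (i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
    }
    for (i = 0; i < n - 1; i++) {  // read the n - 1 tree edges
        cin >> p >> m;
        g[p].push_back(m);
        g[m].push_back(p);
    }
    LL s = 0;
    memset(vis, 0, sizeof(vis));
    for (i = 1; i <= n; i++) {  // try every node as the maximum of the set
        memset(dp, 0, sizeof(dp));
        root = i;
        vis[i] = 1;  // from now on, i may not join sets whose maximum has the same value
        dfs(i, -1);
        s += dp[i];  // valid sets whose maximum is node i (smallest index among ties)
        s %= mod;
    }
    cout << s << endl;
    return 0;
}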