As you know, an undirected connected graph with n nodes and n - 1 edges is called a tree. You are given an integer d and a tree consisting of n nodes. Each node i has a value a_i associated with it.
We call a set S of tree nodes valid if the following conditions are satisfied:
- S is non-empty.
- S is connected. In other words, if nodes u and v are in S, then all nodes lying on the simple path between u and v must also be in S.
- max_{u ∈ S} a_u - min_{v ∈ S} a_v ≤ d.
Your task is to count the number of valid sets. Since the result can be very large, you must print its remainder modulo 1000000007 (10^9 + 7).
The first line contains two space-separated integers d (0 ≤ d ≤ 2000) and n (1 ≤ n ≤ 2000).
The second line contains n space-separated positive integers a_1, a_2, ..., a_n (1 ≤ a_i ≤ 2000).
Each of the next n - 1 lines contains a pair of integers u and v (1 ≤ u, v ≤ n) denoting that there is an edge between u and v. It is guaranteed that these edges form a tree.
Print the number of valid sets modulo 1000000007.
1 4
2 1 3 2
1 2
1 3
3 4
8
0 3
1 2 3
1 2
2 3
3
4 8
7 8 7 5 4 6 4 10
1 6
1 2
5 8
1 3
3 5
6 7
3 4
41
In the first sample, there are exactly 8 valid sets: {1}, {2}, {3}, {4}, {1, 2}, {1, 3}, {3, 4} and {1, 3, 4}. Set {1, 2, 3, 4} is not valid, because the third condition isn't satisfied. Set {1, 4} satisfies the third condition, but conflicts with the second condition.
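The enumeration in the note can be cross-checked with a brute force over all subsets. The sketch below is not part of the original solution: the name bruteForceValidSets is hypothetical, nodes are 0-indexed, and it assumes the third condition is "maximum value minus minimum value in S is at most d". It is only practical for very small n.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper (not from the original post): counts valid sets by
// trying every non-empty subset of nodes. Nodes are 0-indexed here.
long long bruteForceValidSets(int d, const std::vector<int>& a,
                              const std::vector<std::pair<int, int> >& edges) {
    int n = (int)a.size();
    assert(n >= 1 && n <= 20);  // 2^n subsets, so keep n tiny
    long long count = 0;
    for (unsigned mask = 1; mask < (1u << n); ++mask) {
        // Assumed third condition: max - min over the subset is at most d.
        int lo = a[0], hi = a[0];
        bool first = true;
        for (int i = 0; i < n; ++i) {
            if (!((mask >> i) & 1u)) continue;
            if (first) { lo = hi = a[i]; first = false; }
            lo = std::min(lo, a[i]);
            hi = std::max(hi, a[i]);
        }
        if (hi - lo > d) continue;

        // Second condition: flood-fill inside the subset from its lowest node.
        unsigned reached = mask & (~mask + 1u);
        bool grew = true;
        while (grew) {
            grew = false;
            for (size_t k = 0; k < edges.size(); ++k) {
                int u = edges[k].first, v = edges[k].second;
                bool bothInSet = ((mask >> u) & 1u) && ((mask >> v) & 1u);
                bool inU = (reached >> u) & 1u, inV = (reached >> v) & 1u;
                if (bothInSet && inU != inV) {
                    reached |= (1u << u) | (1u << v);
                    grew = true;
                }
            }
        }
        if (reached == mask) ++count;
    }
    return count;
}
```

Called with d = 1, values {2, 1, 3, 2} and edges {0,1}, {0,2}, {2,3} (the first sample shifted to 0-based indices), it should count the same 8 sets listed above.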
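Consider each node in turn as the root and grow a tree from it in which no node has a value larger than the root's and val[root]-val[v]<=d holds for every node v. A tree DP then counts the valid sets that contain this root. If the tree grown from u contains a node v with the same value as u, then the tree grown from v must not include u, otherwise the same set would be counted twice; the code breaks such ties by node index.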
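The counting step can be written as a recurrence. This is a sketch of what the code below computes; the notation f_r, adj and parent is introduced here for illustration and does not appear in the original post.

```latex
f_r(u) \;=\; \prod_{\substack{v \in \mathrm{adj}(u),\; v \neq \mathrm{parent}(u) \\ v \text{ allowed for } r}} \bigl(1 + f_r(v)\bigr),
\qquad
\text{answer} \;=\; \Bigl(\sum_{r=1}^{n} f_r(r)\Bigr) \bmod (10^9 + 7),
```

```latex
v \text{ allowed for } r
\iff
a_r - d \le a_v \le a_r \ \text{ and } \ \bigl(a_v < a_r \ \text{ or } \ v < r\bigr).
```

Each factor 1 + f_r(v) covers the choice of leaving the subtree of v out or attaching one of its f_r(v) connected sets. For the first sample (d = 1, values 2 1 3 2) the roots 1, 2, 3, 4 contribute 2, 1, 4, 1, which sums to 8.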
#include <cstdio>
#include <iostream>
#include <vector>
using namespace std;
typedef long long LL;
const int maxn = 2010;
const int MOD = 1000000007;
vector<int> g[maxn]; // adjacency lists, nodes are 1-indexed
int a[maxn];         // a[i] is the value of node i
int d, N, root;      // root is the node currently fixed as the maximum of the set
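// Counts the connected sets that contain u, stay inside the subtree of u
// (f is the parent of u), and use only nodes allowed for the current root.
LL dfs(int u, int f)
{
    LL ans = 1;
    for (size_t i = 0; i < g[u].size(); i++)
    {
        int v = g[u][i];
        if (v == f) continue;
        // values must stay within [a[root] - d, a[root]]
        if (a[v] > a[root] || a[root] - a[v] > d) continue;
        // tie-break: among equal values only the largest index acts as root
        if (a[v] == a[root] && v > root) continue;
        // either skip the subtree of v (+1) or attach one of its sets
        ans = ans * (dfs(v, u) + 1) % MOD;
    }
    return ans;
}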
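int main()
{
    // the loop accepts several test cases back to back
    while (scanf("%d%d", &d, &N) == 2)
    {
        for (int i = 1; i <= N; i++) scanf("%d", &a[i]);
        for (int i = 0; i <= N; i++) g[i].clear();
        for (int i = 1; i < N; i++)
        {
            int u, v;
            scanf("%d%d", &u, &v);
            g[u].push_back(v);
            g[v].push_back(u);
        }
        LL ans = 0;
        // fix every node in turn as the maximum of the set and sum the counts
        for (int i = 1; i <= N; i++)
        {
            root = i;
            ans = (ans + dfs(i, -1)) % MOD;
        }
        cout << ans << endl;
    }
    return 0;
}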