D. Valid Sets
time limit per test
1 second
memory limit per test
256 megabytes
input
standard input
output
standard output
As you know, an undirected connected graph with n nodes and n - 1 edges is called a tree. You are given an integer d and a tree consisting of n nodes. Each node i has a value ai associated with it.
We call a set S of tree nodes valid if following conditions are satisfied:
- S is non-empty.
- S is connected. In other words, if nodes u and v are in S, then all nodes lying on the simple path between u and v should also be presented in S.
.
Your task is to count the number of valid sets. Since the result can be very large, you must print its remainder modulo 1000000007(109 + 7).
Input
The first line contains two space-separated integers d (0 ≤ d ≤ 2000) and n (1 ≤ n ≤ 2000).
The second line contains n space-separated positive integers a1, a2, ..., an(1 ≤ ai ≤ 2000).
Then the next n - 1 line each contain pair of integers u and v (1 ≤ u, v ≤ n) denoting that there is an edge between u and v. It is guaranteed that these edges form a tree.
Output
Print the number of valid sets modulo 1000000007.
Examples
input
Copy
1 4 2 1 3 2 1 2 1 3 3 4
output
Copy
8
input
Copy
0 3 1 2 3 1 2 2 3
output
Copy
3
input
Copy
4 8 7 8 7 5 4 6 4 10 1 6 1 2 5 8 1 3 3 5 6 7 3 4
output
Copy
41
Note
In the first sample, there are exactly 8 valid sets: {1}, {2}, {3}, {4}, {1, 2}, {1, 3}, {3, 4} and {1, 3, 4}. Set {1, 2, 3, 4} is not valid, because the third condition isn't satisfied. Set {1, 4} satisfies the third condition, but conflicts with the second condition.
题意:
给一棵树,点上有权值,以及d,要求有多少种联通块满足最大值减最小值小于等于d。结果 %1000000007
思路:
枚举每一个点为根节点,而且是最大值。从根节点开始遍历子节点,每个子节点必须满足下列性质
1.每个子节点的权值<=根节点的权值,特别的如果等于根节点的权值,那么必须要求子节点编号小于根节点编号。(避免重复)
2.最大值-最小值<=d
按照乘法原理,每个点都可以选或者不选
假设u是父节点 v是子节点 那么已知dp[u]=dp[u]*dp[v] 特殊的v如果是叶子 那么dp[v]=2 即可选可不选 如果u是不是根 那么dp[u]最后要+1 因为可以不选u
#include <cstdio>
#include <algorithm>
#include <math.h>
#include <string.h>
#include <vector>
#include <iostream>
#define ll long long
#define INF 0x3f3f3f
using namespace std;
const int N=2000+500;
const int mod=1000000007;
vector<int> v[N];
int a[N];
ll dp[N];
int d;
void dfs(int x,int w,int root,int fa)
{
int to,i;
dp[x]=1;
for(i=0;i<v[x].size();i++)
{
to=v[x][i];
if(a[to]==w && to>root) continue;
if(a[to]>w || w-a[to]>d || to==fa) continue;
dfs(to,w,root,x);
dp[x]=(dp[x]*(dp[to]+1))%mod;
}
}
int main()
{
int n,i,uu,vv;
ll ans;
ans=0;
scanf("%d%d",&d,&n);
for(i=1;i<=n;i++) scanf("%d",&a[i]);
for(i=1;i<=n-1;i++)
{
scanf("%d%d",&uu,&vv);
v[uu].push_back(vv);
v[vv].push_back(uu);
}
for(i=1;i<=n;i++)
{
memset(dp,0,sizeof(dp));
dfs(i,a[i],i,i);
ans=(ans+dp[i])%mod;
}
printf("%lld\n",ans);
return 0;
}