题目大意:
就是给你一颗 n n n个点的树,树上有 m m m个关键点,你可以选择若干个关键点组成集合 S S S,这个集合满足任意两点在树上的距离不超过 k k k,问你有多少种选法?
解题思路:
我们考虑树形dp
1.这个状态方程比较难想:
d
p
[
i
]
[
j
]
:
表
示
在
第
i
个
点
所
在
的
子
树
中
关
键
节
点
距
离
i
点
最
远
的
距
离
为
j
的
选
择
方
案
数
dp[i][j]:表示在第i个点所在的子树中关键节点距离i点最远的距离为j的选择方案数
dp[i][j]:表示在第i个点所在的子树中关键节点距离i点最远的距离为j的选择方案数
这
看
起
来
是
不
是
会
算
重
复
实
际
上
,
并
没
有
重
复
因
为
它
表
示
的
是
选
择
方
案
内
要
包
含
一
定
要
包
含
距
离
i
的
距
离
为
j
的
点
这看起来是不是会算重复实际上,并没有重复因为它表示的是选择方案内要包含一定要包含距离i的距离为j的点
这看起来是不是会算重复实际上,并没有重复因为它表示的是选择方案内要包含一定要包含距离i的距离为j的点
2.这个树形dp和常规的树形dp不一样因为它是选点的方案数,那么我们直接一遍dfs,用子节点的答案去更新父节点的答案就好
3.那么我们转移?v
backup[u]是把dp[u]拷贝一边
u是父节点,v是子节点
最后把backup再拷贝会dp中
1.backup[u][max(i,j+1)] += dp[u][i] * dp[v][j]
解释:就是我们对于后面就是距离u节点最远距离为i的方案数*距离v最远距离为j的方案数(组合)
1.1那为什么是max(i,j+1)呢?
解释:因为你看状态转移方程的意义就是最远点的方案数,那么答案就是累加到距离最远的下面。
1.2那为什么是j+1,因为要加上自己到父节点的那个长度。
2.backup[u][i+1] += dp[v][i] 就是直接在把只在v子树里面选择的答案也累加上去
3.如果节点u是一个关键点的话 d p [ u ] [ 0 ] = 1 , d p [ u ] [ i ] = 2 [ i = 1..... k ] dp[u][0] = 1, dp[u][i]=2[i=1.....k] dp[u][0]=1,dp[u][i]=2[i=1.....k]因为这个节点本身解可以实现选和不选一共和其他选择方案组合就是*2.
4.最后答案累计 ∑ i = 1 n d p [ 1 ] [ i ] \sum_{i=1}^{n}dp[1][i] ∑i=1ndp[1][i]
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod = 1e9 + 7;
const int maxn = 5e3 + 10;
int n, m, k;
vector<int>G[maxn];
int node[maxn];
ll dp[maxn][maxn], backup[maxn];
int maxd[maxn];
void dfs(int u, int fa) {
for(auto it : G[u]) {
if(it == fa) continue;
dfs(it,u);
maxd[u] = max(maxd[u],maxd[it]+1);
memcpy(backup,dp[u],sizeof(backup));
for(int i = 0; i <= maxd[u] && i <= k; ++ i)
for(int j = 0; j <= maxd[it] && j <= k; ++ j)
if(i + j + 1 <= k) {
backup[max(i,j+1)] = (backup[max(i,j+1)] + dp[u][i] * dp[it][j] % mod) % mod;
}
for(int i = 0; i <= maxd[u]; ++ i)
backup[i+1] = (backup[i+1] + dp[it][i]) % mod;
for(int i = 0; i <= k; ++ i)
dp[u][i] = backup[i];
}
if(node[u]) {
dp[u][0] = 1;
for(int i = 1; i <= k; ++ i)
dp[u][i] = dp[u][i] * 2 % mod;
}
}
int main() {
cin >> n >> m >> k;
for(int i = 0; i < n - 1; ++ i) {
int l, r;
cin >> l >> r;
G[l].push_back(r);
G[r].push_back(l);
}
for(int i = 0 ; i < m; ++ i) {
int x;
cin >> x;
node[x] = 1;
}
dfs(1,0);
ll ans = 0;
for(int i = 0; i <= maxd[1]; ++ i)
ans = (ans + dp[1][i]) % mod;
cout << ans << endl;
return 0;
}