题目分析
设 f ( x , i , 0 / 1 , 0 / 1 ) f(x,i,0/1,0/1) f(x,i,0/1,0/1)表示以节点 x x x为根的子树,用掉 i i i个窃听器, x x x上是否有窃听器, x x x是否被窃听的方案数。转移的话枚举 x x x的儿子的子树中用了多少窃听器。
注意到假设我DP过程中考虑过了 x x x子树中 s z ( x ) sz(x) sz(x)个点,那么它们上面放的窃听器不会超过 s z ( x ) sz(x) sz(x)个,利用 s z sz sz限制一下转移时枚举窃听器数量的范围,可以将复杂度做到 O ( n k ) O(nk) O(nk)。
复杂度证明:
- 若需要合并DP信息的两边子树大小都大于等于 K K K,这样的合并最多发生 n K \frac{n}{K} Kn次,单次合并的复杂度是 K 2 K^2 K2
- 若一个子树大于等于 K K K,另一个小于,那么将小于的那个的复杂度均摊在它里面每个节点中,都均摊一个 K K K的复杂度,而两个子树合并后,大小一定大于 K K K了,所以每个节点最多被均摊一次这个复杂度。
- 若两个子树大小都小于 K K K,也是考虑均摊,相当于每个点所在的子树一和一个新点合并,就增加1的复杂度,最多不会增加超过 K K K次。
综上,复杂度是 O ( n k ) O(nk) O(nk)的。
代码
高维数组寻址很慢,所以 压成低位数组 偷偷开个O2
#include<bits/stdc++.h>
using namespace std;
#define RI register int
int read() {
int q=0;char ch=' ';
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9') q=q*10+ch-'0',ch=getchar();
return q;
}
const int mod=1000000007,N=100005;
int n,K,tot;
int sz[N],h[N],ne[N<<1],to[N<<1],f[N][105][2][2],tmp[105][2][2];
int qm(int x) {return x>=mod?x-mod:x;}
void add(int x,int y) {to[++tot]=y,ne[tot]=h[x],h[x]=tot;}
void dfs(int x,int las) {
f[x][0][0][0]=f[x][1][1][0]=1,sz[x]=1;
for(RI i=h[x];i;i=ne[i]) {
if(to[i]==las) continue;
int y=to[i];dfs(y,x);
for(RI j=0;j<=K&&j<=sz[x];++j) {
tmp[j][0][0]=f[x][j][0][0],f[x][j][0][0]=0;
tmp[j][0][1]=f[x][j][0][1],f[x][j][0][1]=0;
tmp[j][1][0]=f[x][j][1][0],f[x][j][1][0]=0;
tmp[j][1][1]=f[x][j][1][1],f[x][j][1][1]=0;
}
for(RI j=0;j<=K&&j<=sz[x];++j)
for(RI k=0;j+k<=K&&k<=sz[y];++k) {
f[x][j+k][0][0]=qm(f[x][j+k][0][0]+1LL*tmp[j][0][0]*f[y][k][0][1]%mod);
f[x][j+k][0][1]=qm(f[x][j+k][0][1]+1LL*tmp[j][0][0]*f[y][k][1][1]%mod);
f[x][j+k][0][1]=qm(f[x][j+k][0][1]+1LL*tmp[j][0][1]*f[y][k][0][1]%mod);
f[x][j+k][0][1]=qm(f[x][j+k][0][1]+1LL*tmp[j][0][1]*f[y][k][1][1]%mod);
f[x][j+k][1][0]=qm(f[x][j+k][1][0]+1LL*tmp[j][1][0]*f[y][k][0][0]%mod);
f[x][j+k][1][0]=qm(f[x][j+k][1][0]+1LL*tmp[j][1][0]*f[y][k][0][1]%mod);
f[x][j+k][1][1]=qm(f[x][j+k][1][1]+1LL*tmp[j][1][0]*f[y][k][1][0]%mod);
f[x][j+k][1][1]=qm(f[x][j+k][1][1]+1LL*tmp[j][1][0]*f[y][k][1][1]%mod);
f[x][j+k][1][1]=qm(f[x][j+k][1][1]+1LL*tmp[j][1][1]*qm(f[y][k][0][0]+
qm(f[y][k][0][1]+qm(f[y][k][1][0]+f[y][k][1][1])))%mod);
}
sz[x]+=sz[y];
}
}
int main()
{
int x,y;
n=read(),K=read();
for(RI i=1;i<n;++i) x=read(),y=read(),add(x,y),add(y,x);
dfs(1,0);
printf("%d\n",qm(f[1][K][1][1]+f[1][K][0][1]));
return 0;
}