Description
-
求大小为 n n n的树上所有大小为 k k k的连通块的重心权值之和(如果有两个重心,取编号小的那个)。
-
n ≤ 5 e 4 , k ≤ 500 n\le5e4,k\le500 n≤5e4,k≤500
Solution
-
显然可以得到一种 O ( n k 2 ) O(nk^2) O(nk2)的换根DP,直接枚举子树内选择的点数转移即可。
-
经典套路:如果没有换根DP的部分,直接做大小为 k k k的树形背包的复杂度实际上是 O ( n k ) O(nk) O(nk)的(而不是 O ( n k 2 ) O(nk^2) O(nk2))。
-
证明: 首先合并两个子树 x x x和 y y y的时候,时间复杂度是 m i n ( s z [ x ] , k ) ∗ m i n ( s z [ y ] , k ) min(sz[x],k)*min(sz[y],k) min(sz[x],k)∗min(sz[y],k)的。
- s z [ x ] , s z [ y ] ≥ k sz[x],sz[y]\ge k sz[x],sz[y]≥k时,考虑所有极小的子树使得 s z [ i ] ≥ k sz[i]\ge k sz[i]≥k,那么是缩点后原树上的叶子,最多 n / k n/k n/k个这样的叶子,它们之间合并 n / k n/k n/k次,每一次 k 2 k^2 k2
- s z [ x ] , s z [ y ] < k sz[x],sz[y]<k sz[x],sz[y]<k时,考虑所有极大的子树使得 s z [ i ] < k sz[i]<k sz[i]<k,那么这类合并都在这种子树里面, s z [ i ] < k sz[i]<k sz[i]<k,所以内部的点两两产生贡献,每一个点最多和 k k k个点产生贡献,即 n k nk nk.
- s z [ x ] ≥ k , s z [ y ] < k sz[x]\ge k,sz[y]<k sz[x]≥k,sz[y]<k时,考虑 y y y的子树内的每个节点,产生 k k k的贡献,所以 n k nk nk。
-
既然如此,记录一个 f [ x ] [ i ] f[x][i] f[x][i]和 g [ x ] [ i ] g[x][i] g[x][i]分别表示 x x x子树内连通块大小为 i i i的方案数和重心权值和,做背包DP即可。
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#define min(a,b) ((a<b)?a:b)
#define maxn 50005
#define maxm 505
#define ll long long
#define mo 1000000007
using namespace std;
int n,m,i,j,k,a[maxn],sz[maxn];
int em,e[maxn*2],nx[maxn*2],ls[maxn];
ll f[maxn][maxm],g[maxn][maxm],ans;
void read(int &x){
x=0; char ch=getchar();
for(;ch<'0'||ch>'9';ch=getchar());
for(;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
}
void insert(int x,int y){
em++; e[em]=y; nx[em]=ls[x]; ls[x]=em;
em++; e[em]=x; nx[em]=ls[y]; ls[y]=em;
}
int d[maxn],fa[maxn];
void doit(int x,int p){
sz[x]=1,f[x][1]=1;
for(int i=ls[x];i;i=nx[i]) if (e[i]!=p) {
int y=e[i];
for(int j=min(m-1,sz[x]);j>=m/2;j--) for(int k=min(min(m-j,m/2),sz[y]);k;k--)
(g[x][j+k]+=g[x][j]*f[y][k])%=mo;
for(int j=min(m-1,sz[y]);j>=m/2;j--) for(int k=min(min(m-j,m/2),sz[x]);k;k--)
(g[x][j+k]+=g[y][j]*f[x][k])%=mo;
for(int j=min(m-1,sz[x]);j;j--) for(int k=min(min(m-j,m/2-((m&1)==0)),sz[y]);k;k--)
(f[x][j+k]+=f[x][j]*f[y][k])%=mo;
sz[x]+=sz[y];
}
if (m%2==0&&p) (g[x][m/2]+=f[x][m/2]*a[min(x,p)])%=mo;
for(int i=m/2+1;i<=m;i++) if (f[x][i])
(g[x][i]+=f[x][i]*a[x])%=mo;
ans+=g[x][m];
}
void bfs(){
int t=0,w=1; d[1]=1;
while (t<w){
int x=d[++t];
for(int i=ls[x];i;i=nx[i]) if (e[i]!=fa[x])
d[++w]=e[i],fa[e[i]]=x;
}
while (w){
int x=d[w--];
doit(x,fa[x]);
}
}
int main(){
// freopen("centroid.in","r",stdin);
// freopen("centroid.out","w",stdout);
read(n),read(m);
for(i=1;i<=n;i++) read(a[i]);
for(i=1;i<n;i++) read(j),read(k),insert(j,k);
bfs();
printf("%lld",ans%mo);
}