解题思路:
第一个想法是枚举第 k k 大的值,把大于的记为1,小于的记为0,问题就转化为树上联通块大小等于的个数。
稍微转化一下,我们统计树上联通块第
k
k
大大等于的个数,不妨记为
ai
a
i
,那么
而因为这样计算每个大等于
i
i
的方案在到
ai
a
i
中都会被算一次,恰好被计算了
i
i
次,所以
求
ai
a
i
也可以等价转化成求联通块中大等于
i
i
大等于个的方案数。
设
fu,i,j
f
u
,
i
,
j
表示表示
u
u
的子树中包含且大等于
i
i
的个数为的联通块个数,那么直接做树上背包就是
O(n3)
O
(
n
3
)
的复杂度,卡常据说可以过。
将第三维用生成函数的形式表达,即:
Fu,i=∑j=0nfu,i,jxj
F
u
,
i
=
∑
j
=
0
n
f
u
,
i
,
j
x
j
。
考虑普通的树上背包dp用多项式乘法形式表现,即有转移:
再设 Gu,i G u , i 是 u u 子树内之和,那么答案就是 ∑i=1WG1,i ∑ i = 1 W G 1 , i 中 k k 次项之后所有系数之和。
如果暴力维护多项式乘法是的。
注意到答案最后是个 n n 次多项式,考虑插个 x x 的值进去,转化成点值来算,这样乘法就是的,而且转移 F F 时连续一段乘的是相同的值,所以可以用线段树维护。
转移流程大概如下:
- (f,g)=(1,0) ( f , g ) = ( 1 , 0 ) //初始化,线段树整体覆盖
- (f,g)=(f(fv+1),g+gv) ( f , g ) = ( f ( f v + 1 ) , g + g v ) //线段树合并
- (f,g)=(fx0,g) ( f , g ) = ( f x 0 , g ) //线段树区间修改
- (f,g)=(f,g+f) ( f , g ) = ( f , g + f ) //线段树整体修改
最后 f,g f , g 会变为 af+b,cf+d a f + b , c f + d 的形式,类似线段树维护乘法和加法,我们维护 a,b,c,d a , b , c , d 四个值的转移。
最后 ∑i=1WG1,i ∑ i = 1 W G 1 , i 也是个多项式,所以我们只需要把点值对应项相加后求一遍系数就行了。
时间复杂度是 O(n2logW) O ( n 2 l o g W ) 的,然而没有比暴力快……汗
#include<bits/stdc++.h>
using namespace std;
int getint()
{
int i=0,f=1;char c;
for(c=getchar();c!='-'&&(c<'0'||c>'9');c=getchar());
if(c=='-')f=-1,c=getchar();
for(;c>='0'&&c<='9';c=getchar())i=(i<<3)+(i<<1)+c-'0';
return i*f;
}
const int N=2005,mod=64123;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
int Pow(int x,int y)
{
int res=1;
for(;y;y>>=1,x=mul(x,x))
if(y&1)res=mul(res,x);
return res;
}
int n,k,W,d[N];vector<int>e[N];
int tot,pool_top,pool[N*50],rt[N],lc[N*50],rc[N*50];
int inv[N],yc[N],c[N],g[N],f[N];
struct data
{
int a,b,c,d;
data():a(1),b(0),c(0),d(0){}
data(int _a,int _b,int _c,int _d):a(_a),b(_b),c(_c),d(_d){}
inline friend data operator * (const data &a,const data &b)
{
return data(mul(a.a,b.a),
add(mul(b.a,a.b),b.b),
add(mul(b.c,a.a),a.c),
add(mul(b.c,a.b),add(a.d,b.d)));
}
}tag[N*50];
inline int newnode()
{
int x=pool_top?pool[pool_top--]:++tot;
tag[x]=data(),lc[x]=rc[x]=0;
return x;
}
void del(int &x)
{
if(!x)return;
del(lc[x]),del(rc[x]);
pool[++pool_top]=x,x=0;
}
void pushdown(int x)
{
if(!lc[x])lc[x]=newnode();
if(!rc[x])rc[x]=newnode();
tag[lc[x]]=tag[lc[x]]*tag[x];
tag[rc[x]]=tag[rc[x]]*tag[x];
tag[x]=data();
}
void modify(int &k,int l,int r,int x,int y,data tg)
{
if(!k)k=newnode();
if(x<=l&&r<=y){tag[k]=tag[k]*tg;return;}
pushdown(k);int mid=l+r>>1;
if(x<=mid)modify(lc[k],l,mid,x,y,tg);
if(y>mid)modify(rc[k],mid+1,r,x,y,tg);
}
void merge(int &x,int &y)
{
if(!x)swap(x,y);
if(!y)return;
if(!lc[x]&&!rc[x])swap(x,y);
if(!lc[y]&&!rc[y])
{
tag[x].a=mul(tag[x].a,tag[y].b);
tag[x].b=mul(tag[x].b,tag[y].b);
tag[x].d=add(tag[x].d,tag[y].d);
return;
}
pushdown(x),pushdown(y);
merge(lc[x],lc[y]),merge(rc[x],rc[y]);
}
void Init(int u,int fa,int x0)
{
modify(rt[u],1,W,1,W,data(0,1,0,0));
for(int i=0;i<e[u].size();i++)
{
int v=e[u][i];
if(v==fa)continue;
Init(v,u,x0);
merge(rt[u],rt[v]);
del(rt[v]);
}
if(d[u])modify(rt[u],1,W,1,d[u],data(x0,0,0,0));
modify(rt[u],1,W,1,W,data(1,0,1,0));
modify(rt[u],1,W,1,W,data(1,1,0,0));
}
void get_val(int k,int l,int r,int i)
{
if(l==r){yc[i]=add(yc[i],tag[k].d);return;}
pushdown(k);
int mid=l+r>>1;
get_val(lc[k],l,mid,i),get_val(rc[k],mid+1,r,i);
}
void div(int *a,int *b,int x0)
{
for(int i=0;i<=n+1;i++)c[i]=a[i];
for(int i=n+1;i>=1;i--)
{
b[i-1]=c[i];
c[i-1]=add(c[i-1],mul(c[i],x0)),c[i]=0;
}
}
int get_ans()
{
int ans=0;
for(int i=1;i<=n+1;i++)inv[i]=Pow(i,mod-2);
g[0]=1;
for(int i=1;i<=n+1;i++)
for(int j=n+1;j>=0;j--)
{
g[j]=mul(g[j],mod-i);
if(j)g[j]=add(g[j],g[j-1]);
}
for(int i=1;i<=n+1;i++)
{
div(g,f,i);int res=0;
for(int j=k;j<=n;j++)res=add(res,f[j]);
for(int j=1;j<=n+1;j++) if(i!=j)
res=(i>j?mul(res,inv[i-j]):mul(res,mod-inv[j-i]));
res=mul(res,yc[i]),ans=add(ans,res);
}
return ans;
}
int main()
{
//freopen("lx.in","r",stdin);
n=getint(),k=getint(),W=getint();
for(int i=1;i<=n;i++)d[i]=getint();
for(int i=1;i<n;i++)
{
int x=getint(),y=getint();
e[x].push_back(y),e[y].push_back(x);
}
for(int i=1;i<=n+1;i++)
{
Init(1,0,i);
get_val(rt[1],1,W,i);
del(rt[1]);
}
cout<<get_ans();
return 0;
}