题解:
一个显然的做法是枚举每个点计算贡献,把大于他的记为1,小于他的记为0,问题就转化为树上联通块大小等于k的个数。
稍微转化一下,我们统计树上联通块第 k k 大大等于的个数,不妨记为 ai a i ,那么:
我们发现,其实
考虑如何求
ai
a
i
,我们做一个等价转化,变为统计
≥i
≥
i
的个数有
≥k
≥
k
个的联通块的方案数。 考虑树上背包的做法:
记
fp,i,j
f
p
,
i
,
j
表示
p
p
的子树中大等于的个数为
j
j
的方案数。
将第三维用生成函数的形式表达,则:
有转移:
注意一段区间 i i 的多项式相同,用线段树合并来维护这个多项式的转移(即对应位置合并的时候要相乘),我们就可以在的时间内计算出所有 F F 。
还可以更优,瓶颈在于多项式乘法,考虑拉格朗日插值:
记,乘法变为
O(1)
O
(
1
)
,我们计算
n+1
n
+
1
个点值,即可在
O(n2logW)
O
(
n
2
log
W
)
的时间内计算出
Fp
F
p
的
n
n
个点值。
不过插值还原是的,每个点再还原一次时间复杂度依然是 O(n3) O ( n 3 ) ,注意到点值加法依然合法,我们可以多维护在线段树上多维护一个 sum s u m 表示原子树所有点值的和,线段树合并时对应位置相加即可。
那么一个点的状态为
(f,g)
(
f
,
g
)
,表示当前的点值,子树点值和。
现在的转移即为:
(f,g)=(1,0) ( f , g ) = ( 1 , 0 ) //初始化,线段树覆盖
(f,g)→(f(1+fv),g+gv) ( f , g ) → ( f ( 1 + f v ) , g + g v ) // 线段树合并
(f,g)→(fx0,g) ( f , g ) → ( f x 0 , g ) //线段树修改区间
- (f,g)→(f,g+f) ( f , g ) → ( f , g + f ) //线段树整体修改
注意修改其实很繁琐,我们统一用矩阵来表示转移。
不过矩阵常数大,我们直接用函数的复合来维护
tag
t
a
g
:
f
f
只会变为,
g
g
只会变为,我们维护
a,b,c,d
a
,
b
,
c
,
d
,然后就变成函数复合了,常数要小很多。
具体线段树合并的方法:
注意到函数复合跟顺序有关,我们合并
(x,y)
(
x
,
y
)
的时候一路把
tag
t
a
g
下放到
(x,y)
(
x
,
y
)
中的一个没有左右儿子的区间(实际上这跟线段树合并的过程类似)。
那么有一边的函数为定值(下方没有 tag t a g ),另外一边的函数直接乘上这个定值即可。
具体拉格朗日插值的方法:
注意式子
记 w=Π(x−xi) w = Π ( x − x i )
则公式变为:
∑ni=1w(x−xi)ai ∑ i = 1 n w ( x − x i ) a i
ai=yiΠ(xi−xj) a i = y i Π ( x i − x j )
O(n2) O ( n 2 ) 解决,实际上这就是重心拉格朗日插值。
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int RLEN=1<<18|1;
inline char nc() {
static char ibuf[RLEN],*ib,*ob;
(ib==ob) && (ob=(ib=ibuf)+fread(ibuf,1,RLEN,stdin));
return (ib==ob) ? -1 : *ib++;
}
inline int rd() {
char ch=nc(); int i=0,f=1;
while(!isdigit(ch)) {if(ch=='-')f=-1; ch=nc();}
while(isdigit(ch)) {i=(i<<1)+(i<<3)+ch-'0'; ch=nc();}
return i*f;
}
const int N=1700, mod=64123;
inline int add(int x,int y) {return (x+y>=mod) ? (x+y-mod) : (x+y);}
inline int dec(int x,int y) {return (x-y<0) ? (x-y+mod) : (x-y);}
inline int mul(int x,int y) {return (LL)x*y%mod;}
inline int power(int a,int b,int rs=1) {for(;b;b>>=1, a=mul(a,a)) if(b&1) rs=mul(rs,a); return rs;}
struct data {
int a,b,c,d;
//(f,g)
//-> (a1*f+b1, c1*f+d1+g)
//-> ( (a1*a2)*f + a2*b1+b2, (a1*c2+c1)*f + (b1*c2 + d1 + d2) + g)
data() : a(1),b(0),c(0),d(0) {}
data(int a,int b,int c,int d) : a(a),b(b),c(c),d(d) {}
friend inline data operator *(const data &a,const data &b) {
return data(mul(a.a,b.a),
add(mul(b.a,a.b),b.b),
add(mul(a.a,b.c),a.c),
add(mul(a.b,b.c),add(a.d,b.d)));
}
};
int n, k, W, d[N], yc[N], inv[N];
vector <int> edge[N];
int lc[N*50], rc[N*50], rt[N], stk[N*50], tl, tot;
data tag[N*50];
inline int newnode() {
int t;
if(tl) t=stk[tl--];
else t=++tot;
tag[t]=data();
lc[t]=rc[t]=0;
return t;
}
inline void pushdown(int k) {
if(!lc[k]) lc[k]=newnode();
if(!rc[k]) rc[k]=newnode();
tag[lc[k]]=tag[lc[k]]*tag[k];
tag[rc[k]]=tag[rc[k]]*tag[k];
tag[k]=data();
}
inline void getval(int x,int l,int r,int i) {
if(l==r) {yc[i]=add(yc[i],tag[x].d); return;}
pushdown(x);
int mid=(l+r)>>1;
getval(lc[x],l,mid,i);
getval(rc[x],mid+1,r,i);
}
inline void del(int &k) {
if(!k) return;
del(lc[k]);
del(rc[k]);
stk[++tl]=k;
k=0;
}
inline void modify(int &k,int l,int r,int L,int R,data tg) {
if(!k) k=newnode();
if(L<=l&&r<=R) {tag[k]=tag[k]*tg; return;}
pushdown(k); int mid=(l+r)>>1;
if(R<=mid) modify(lc[k],l,mid,L,R,tg);
else if(L>mid) modify(rc[k],mid+1,r,L,R,tg);
else modify(lc[k],l,mid,L,R,tg), modify(rc[k],mid+1,r,L,R,tg);
}
inline void merge(int &x,int &y) {
if(!x) swap(x,y);
if(!y) return;
if(lc[x]==0 && rc[x]==0) swap(x,y);
if(lc[y]==0 && rc[y]==0) {
tag[x].a=mul(tag[x].a,tag[y].b);
tag[x].b=mul(tag[x].b,tag[y].b);
tag[x].d=add(tag[x].d,tag[y].d);
return;
}
pushdown(x); pushdown(y);
merge(lc[x],lc[y]); merge(rc[x],rc[y]);
}
inline void lgip(int x,int fa,int x0) {
modify(rt[x],1,W,1,W,data(0,1,0,0));
for(auto v:edge[x])
if(v!=fa) {
lgip(v,x,x0);
merge(rt[x],rt[v]);
del(rt[v]);
}
if(d[x]) modify(rt[x],1,W,1,d[x],data(x0,0,0,0));
modify(rt[x],1,W,1,W,data(1,0,1,0));
modify(rt[x],1,W,1,W,data(1,1,0,0));
}
inline void dec(int *a,int *b,int x0) {
static int tmp[N];
for(int i=0;i<=n+1;i++) tmp[i]=a[i];
for(int i=n+1;i>=1;i--) {
b[i-1]=tmp[i];
tmp[i-1]=add(tmp[i-1],mul(x0,tmp[i]));
}
}
inline int getans() {
static int g[N], f[N], ans;
g[0]=1;
for(int i=n+1;i>=1;--i)
for(int j=n+1;j>=0;j--) {
g[j]=mul(mod-i,g[j]);
if(j) g[j]=add(g[j],g[j-1]);
}
for(int i=1;i<=n+1;i++) {
dec(g,f,i); int rs=0;
for(int j=k;j<=n;j++) rs=add(rs,f[j]);
for(int j=1;j<=n+1;j++) if(i!=j) {
if(j<i) rs=mul(rs,inv[i-j]);
else rs=mul(rs,mod-inv[j-i]);
}
rs=mul(rs,yc[i]); ans=add(ans,rs);
}
return ans;
}
int main() {
n=rd(), k=rd(), W=rd();
for(int i=1;i<=n;i++) d[i]=rd(), inv[i]=power(i,mod-2);
for(int i=1;i<n;i++) {
int x=rd(), y=rd();
edge[x].push_back(y);
edge[y].push_back(x);
}
for(int i=1;i<=n+1;i++) {
lgip(1,0,i);
getval(rt[1],1,W,i);
del(rt[1]);
}
cout<<getans();
}