这个正解是典型的高分低能啊。
(4个标记的胡乱下传的)线段树优化整体DP。
多项式(暴力)求点值化解卷积。
最后还用一个秀得不行的拉格朗日插值法求系数。
然鹅
O
(
n
2
log
k
)
O(n^2\log k)
O(n2logk)跑不过
O
(
n
2
k
)
O(n^2k)
O(n2k)
花里胡哨。
大佬1
大佬2
还有我那巨丑无比的代码:
PS1:用结构体维护标记下传可能会让代码变得简洁(龟速)?
PS2:模数不支持NTT和FFT时我们也可以用(暴力)求点值规避卷积。
#include<bits/stdc++.h>
#define maxn 2005
#define LL long long
using namespace std;
const LL mod = 64123;
LL n,k,W;
LL info[maxn],Prev[maxn<<1],to[maxn<<1],cnt_e;
inline void Node(LL u,LL v){Prev[++cnt_e]=info[u],info[u]=cnt_e,to[cnt_e]=v;}
struct data
{
LL a,b,c,d;
data(LL a=0,LL b=0,LL c=0,LL d=0):a(a),b(b),c(c),d(d){}
data operator *(const data &B)const
{ return data(a*B.a%mod,(B.a*b%mod+B.b)%mod,(c+B.c*a%mod)%mod,(B.c*b%mod+d+B.d)%mod); }
bool not_e(){ return a!=1 || b!=0 || c!=0 || d!=0; }
void set_e(){ a=1,b=0,c=0,d=0; }
};
struct tree
{
LL lc,rc;
data val;
}tr[maxn*100];
LL rt[maxn],tot,d[maxn];
stack<LL>pt;
LL newnode(LL res=0)
{
if(!pt.empty())
{
res = pt.top();
pt.pop();
}
else res = ++tot;
tr[res] . val . set_e();
return res;
}
void Del(LL &now)
{
if(!now) return;
Del(tr[now].lc),Del(tr[now].rc);
tr[now].val.set_e() , tr[now].lc = tr[now].rc = 0;
pt.push(now);
now = 0;
}
void dt(LL now)
{
if(tr[now].val.not_e())
{
if(!tr[now].lc) tr[now].lc = newnode();
if(!tr[now].rc) tr[now].rc = newnode();
LL lc = tr[now].lc , rc = tr[now].rc;
tr[lc].val = tr[lc].val * tr[now].val;
tr[rc].val = tr[rc].val * tr[now].val;
tr[now].val.set_e();
}
}
void Opt(LL &now,LL l,LL r,LL ql,LL qr,const data &v)
{
if(r < ql || l > qr) return;
if(!now) now = newnode();
if(ql <= l && r <= qr)
{
tr[now].val = tr[now].val * v;
return;
}
if(l == r) return;
LL mid = (l+r) >> 1;
dt(now);
Opt(tr[now].lc,l,mid,ql,qr,v) , Opt(tr[now].rc,mid+1,r,ql,qr,v);
}
LL Merge(LL &x,LL &y)
{
if(!x || !y) return x+y;
if(!tr[x].lc && !tr[x].rc) swap(x,y);
if(!tr[y].lc && !tr[y].rc)
{
tr[x].val = tr[x].val * data(tr[y].val.b,0,0,0);
tr[x].val = tr[x].val * data(1,0,0,tr[y].val.d);
return x;
}
dt(x),dt(y);
tr[x].lc = Merge(tr[x].lc,tr[y].lc);
tr[x].rc = Merge(tr[x].rc,tr[y].rc);
return x;
}
void dfs(LL now,LL ff,LL X)
{
Opt(rt[now],1,W,1,W,data(0,1,0,0));
for(LL i=info[now];i;i=Prev[i])
if(to[i]!=ff)
{
dfs(to[i],now,X);
rt[now]=Merge(rt[now],rt[to[i]]);
Del(rt[to[i]]);
}
Opt(rt[now],1,W,1,d[now],data(X,0,0,0));
Opt(rt[now],1,W,1,W,data(1,0,1,0));
Opt(rt[now],1,W,1,W,data(1,1,0,0));
}
LL Query(LL now,LL l,LL r)
{
if(l == r) return tr[now].val.d;
dt(now);
LL mid = (l+r) >> 1;
return (Query(tr[now].lc,l,mid) + Query(tr[now].rc,mid+1,r)) % mod;
}
LL ans[maxn],inv[maxn]={1,1},invf[maxn]={1,1},spol[maxn],rpol[maxn],zpol[maxn];
int main()
{
cin>>n>>k>>W;
for(LL i=1;i<=n;i++) cin>>d[i];
for(LL i=1;i<n;i++)
{
LL u,v;
cin>>u>>v;
Node(u,v),Node(v,u);
}
for(LL x=1;x<=n+1;x++)
{
dfs(1,0,x);
ans[x] = Query(rt[1],1,W);
Del(rt[1]);
}
for(LL i=2;i<=n+1;i++)
inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod,
invf[i] = 1ll * invf[i-1] * inv[i] % mod;
spol[0] = 1;
for(LL i=1;i<=n+1;i++)
{
for(int j=n+1;j>=0;j--)
spol[j+1] = (spol[j+1] + spol[j]) % mod,
spol[j] = spol[j] * (mod-i) % mod;
}
for(LL i=1;i<=n+1;i++)
{
memcpy(rpol,spol,sizeof spol);
LL cs = ans[i] * invf[i-1] % mod * (((n+1-i)&1) ? mod-1 : 1) * invf[n+1-i] % mod;
for(LL j=0;j<=n;j++)
rpol[j] = rpol[j] * (mod-inv[i]) % mod,
rpol[j+1] = (mod + rpol[j+1] - rpol[j]) % mod,
zpol[j] = (zpol[j] + rpol[j] * cs % mod) % mod;
}
LL ret = 0;
for(LL i=k;i<=n;i++)
ret = (ret + zpol[i]) % mod;
cout<<(ret+mod)%mod;
}