BZOJ传送门
洛谷传送门
解析:
首先这道题 O ( n k ( n − k ) ) O(nk(n-k)) O(nk(n−k))的暴力就已经可以过了(而且不用开O2)。(据说这个上界不容易被卡满)
而标算不开 O 2 O2 O2过不去,开了还是跑不过暴力。。。
先来说说暴力:
直接考虑DP出每个点作为第 k k k大的方案数有多少,然后算就行了。
由于我们只需要考虑第 k k k大,这里可以剪一剪枝,最多只需要跑 O ( n − k ) O(n-k) O(n−k)次,每一次可以直接采用树上背包来在每个儿子转移父亲的时候来一个 O ( k ) O(k) O(k)的加就行了,而有 O ( n ) O(n) O(n)个转移,所以总的复杂度为 O ( n k ( n − k ) ) O(nk(n-k)) O(nk(n−k))。
被暴力踩爆的标算:
这个标算的常数。。。算了已经不想吐槽了。
先来转化一下问题,记
a
i
a_i
ai为树上联通块第
k
k
k大大于等于
i
i
i的方案数:
A
n
s
=
∑
i
=
1
i
(
a
i
−
a
i
+
1
)
Ans=\sum_{i=1}i(a_{i}-a_{i+1})
Ans=i=1∑i(ai−ai+1)
拆开直接得到: A n s = ∑ i = 1 a i Ans=\sum_{i=1}a_i Ans=i=1∑ai
于是,等价地改变一下 a i a_i ai的意义:联通块中大于等于 i i i的权值不少于 k k k个点的方案数。
转回树上背包。令 f u , i , j f_{u,i,j} fu,i,j表示在 u u u的子树中形成的联通块,包含 u u u,大于等于 i i i的个数为 j j j的方案数。这个就可以直接做 O ( n k 2 ) O(nk^2) O(nk2)的背包了。
考虑将第三维转化成生成函数的形式: F u , i = ∑ j = 0 f u , i , j x j F_{u,i}=\sum\limits_{j=0}f_{u,i,j}x^j Fu,i=j=0∑fu,i,jxj
发现这个转移是一个多项式卷积: F u , i = ( ∏ v ∈ s u b t r e e ( u ) ( F v , i + 1 ) ) ∗ { 1 i > d u x i ≤ d u F_{u,i}=(\prod_{v\in subtree(u)}(F_{v,i}+1))*\left\{\begin{aligned}1 &&i>d_u\\x&&i\leq d_u\end{aligned}\right. Fu,i=(v∈subtree(u)∏(Fv,i+1))∗{1xi>dui≤du
设 G u , i G_{u,i} Gu,i是 u u u子树内部所有 F v , i F_{v,i} Fv,i之和,那么 ∑ j = 1 W G 1 , j \sum\limits_{j=1}^{W}G_{1,j} j=1∑WG1,j的 k k k次项之后的所有系数之和就是我们要的答案。
直接计算复杂度为 O ( n 4 ) O(n^4) O(n4), F F T FFT FFT可以优化至 O ( n 3 log n ) O(n^3\log n) O(n3logn)甚至没有暴力优秀。
但是我们为什么要求这个多项式呢?
注意答案最后是个 n n n次多项式,转拉格朗日插值,我们考虑插 n + 1 n+1 n+1个值进去。转成点值,这样单点乘法和加法就是 O ( 1 ) O(1) O(1)的了,而且一段区间的乘的是相同的值,于是可以用线段树来维护。
令一点的状态表示为 ( f , g ) (f,g) (f,g),分别表示当前点值和子树点值和(线段树上),于是我们需要维护的东西就是这些:
- ( f , g ) → ( 1 , 0 ) (f,g)\rightarrow(1,0) (f,g)→(1,0) \\初始化,线段树整体覆盖
- ( f , g ) → ( f ( 1 + f v ) , g + g v ) (f,g)\rightarrow(f(1+f_v),g+g_v) (f,g)→(f(1+fv),g+gv)\\线段树合并
- ( f , g ) → ( f x 0 , g ) (f,g)\rightarrow(fx_0,g) (f,g)→(fx0,g)\\线段树区间修改
- ( f , g ) → ( f , g + f ) (f,g)\rightarrow(f,g+f) (f,g)→(f,g+f)\\线段树整体修改
本来这个东西已经可以用矩阵来维护了,但是矩阵常数太大,我们可能需要换一种方法。
注意到改变的形式只有可能是这样: ( f , g ) → ( a f + b , g + c f + d ) (f,g)\rightarrow(af+b,g+cf+d) (f,g)→(af+b,g+cf+d)
于是考虑利用 ( a , b , c , d ) (a,b,c,d) (a,b,c,d)四元组来表示改变的操作。(所以常数巨大)
在合并线段树的时候,一路下放标记,直到一个没有左右儿子的区间。
这样,没有tag的那一边就是定值,另一边的函数直接乘上这个定值就行了。
最后利用重心拉格朗日插值还原系数,求出答案。
代码(标算):
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define uint unsigned int
#define re register
#define gc get_char
#define cs const
namespace IO{
inline char get_char(){
static cs int Rlen=1<<20|1;
static char buf[Rlen],*p1,*p2;
return (p1==p2)&&(p2=(p1=buf)+fread(buf,1,Rlen,stdin),p1==p2)?EOF:*p1++;
}
inline int getint(){
re char c;
while(!isdigit(c=gc()));re int num=c^48;
while(isdigit(c=gc()))num=(num+(num<<2)<<1)+(c^48);
return num;
}
}
using namespace IO;
cs uint mod=64123;
inline uint add(uint a,uint b){return a+b>=mod?a+b-mod:a+b;}
inline uint dec(uint a,uint b){return a<b?a-b+mod:a-b;}
inline uint mul(uint a,uint b){return a*b%mod;}
cs int N=1670,B=50;
int n,k,W;
int d[N];
int last[N],nxt[N<<1],to[N<<1],ecnt;
inline void addedge(int u,int v){
nxt[++ecnt]=last[u],last[u]=ecnt,to[ecnt]=v;
nxt[++ecnt]=last[v],last[v]=ecnt,to[ecnt]=u;
}
int rt[N],lc[N*B],rc[N*B],sta[N*B],top;
struct data{
uint a,b,c,d;
data():a(1),b(0),c(0),d(0){}
data(cs uint &_a,cs uint &_b,cs uint &_c,cs uint &_d):a(_a),b(_b),c(_c),d(_d){}
friend data operator*(cs data &l,cs data &r){
return data(
mul(l.a,r.a),
add(mul(l.b,r.a),r.b),
add(mul(l.a,r.c),l.c),
add(mul(l.b,r.c),add(l.d,r.d))
);
}
friend data operator*=(data &l,cs data &r){
return l=l*r;
}
}val[N*B];
inline int newnode(){
static int tot=0;
int now=top?sta[top--]:++tot;
lc[now]=rc[now]=0;
return now;
}
inline void del(int &x){
if(!x)return ;
del(lc[x]);del(rc[x]);
sta[++top]=x;
val[x]=data();
x=0;
}
inline void pushdown(int k){
if(!lc[k])lc[k]=newnode();
if(!rc[k])rc[k]=newnode();
val[lc[k]]*=val[k];
val[rc[k]]*=val[k];
val[k]=data();
}
inline int merge(int &x,int &y){
if(!x||!y)return x|y;
if(!lc[x]&&!rc[x])swap(x,y);
if(!lc[y]&&!rc[y]){
val[x]*=data(val[y].b,0,0,0);
val[x]*=data(1,0,0,val[y].d);
return x;
}
pushdown(x),pushdown(y);
lc[x]=merge(lc[x],lc[y]);
rc[x]=merge(rc[x],rc[y]);
return x;
}
inline int query(int k,int l,int r){
if(l==r)return val[k].d;
int mid=(l+r)>>1;
pushdown(k);
return add(query(lc[k],l,mid),query(rc[k],mid+1,r));
}
inline void update(int &k,int l,int r,cs int &ql,cs int &qr,cs data &val){
if(!k)k=newnode();
if(ql<=l&&r<=qr){::val[k]*=val;return ;}
int mid=(l+r)>>1;
pushdown(k);
if(ql<=mid)update(lc[k],l,mid,ql,qr,val);
if(mid<qr)update(rc[k],mid+1,r,ql,qr,val);
}
void dfs(int u,int fa,int x_0){
update(rt[u],1,W,1,W,data(0,1,0,0));
for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]])
if(v^fa){
dfs(v,u,x_0);
merge(rt[u],rt[v]);
del(rt[v]);
}
update(rt[u],1,W,1,d[u],data(x_0,0,0,0));
update(rt[u],1,W,1,W,data(1,0,1,0));
update(rt[u],1,W,1,W,data(1,1,0,0));
}
inline void dec(int *a,int *b,int x_0){
static int tmp[N];
memcpy(tmp,a,sizeof(int)*(n+2));
for(int re i=n+1;i;--i){
b[i-1]=tmp[i];
tmp[i-1]=add(tmp[i-1],mul(x_0,tmp[i]));
}
}
int g[N],f[N],y[N];
inline int lagrange(){
static int inv[N];
int ans=0;inv[0]=inv[1]=1;
for(int re i=2;i<=n;++i)
inv[i]=mul(mod-mod/i,inv[mod%i]);
g[0]=1;
for(int re i=n+1;i;--i)
for(int re j=n+1;~j;--j){
g[j]=mul(mod-i,g[j]);
if(j)g[j]=add(g[j],g[j-1]);
}
for(int re i=1;i<=n+1;++i){
dec(g,f,i);
int res=0;
for(int re j=k;j<=n;++j)res=add(res,f[j]);
for(int re j=1;j<=n+1;++j)if(i^j){
if(j<i)res=mul(res,inv[i-j]);
else res=mul(res,mod-inv[j-i]);
}
res=mul(res,y[i]);
ans=add(ans,res);
}
return ans;
}
signed main(){
n=getint();k=getint();W=getint();
for(int re i=1;i<=n;++i)d[i]=getint();
for(int re i=1,u,v;i<n;++i){
u=getint(),v=getint();
addedge(u,v);
}
for(int re i=1;i<=n+1;++i){
dfs(1,0,i);
y[i]=query(rt[1],1,W);
del(rt[1]);
}
cout<<lagrange();
return 0;
}
代码(暴力):
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define re register
#define gc get_char
#define cs const
namespace IO{
inline char get_char(){
static cs int Rlen=1<<20|1;
static char buf[Rlen],*p1,*p2;
return (p1==p2)&&(p2=(p1=buf)+fread(buf,1,Rlen,stdin),p1==p2)?EOF:*p1++;
}
inline int getint(){
re char c;
while(!isdigit(c=gc()));re int num=c^48;
while(isdigit(c=gc()))num=(num+(num<<2)<<1)+(c^48);
return num;
}
}
using namespace IO;
cs int mod=64123;
cs int N=1670;
int last[N],nxt[N<<1],to[N<<1],ecnt;
inline void addedge(int u,int v){
nxt[++ecnt]=last[u],last[u]=ecnt,to[ecnt]=v;
nxt[++ecnt]=last[v],last[v]=ecnt,to[ecnt]=u;
}
struct node{
int x,y;
friend bool operator<(cs node &a,cs node &b){
return a.x<b.x||(a.x==b.x&&a.y<b.y);
}
}a[N];
int d[N],f[N][N];
bool v[N];
int n,k,w;
void dfs(int u,int fa){
for(int re i=1;i+v[u]<=k;++i)
f[u][i+v[u]]=f[fa][i];
for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]])
if(v^fa)dfs(v,u);
for(int re i=1;i<=k;++i){
f[fa][i]+=f[u][i];
if(f[fa][i]>=mod)f[fa][i]-=mod;
}
}
int ans;
signed main(){
n=getint();k=getint();w=getint();
for(int re i=1;i<=n;++i)d[i]=a[i].x=getint(),a[i].y=i;
for(int re i=1;i<n;++i)addedge(getint(),getint());
sort(a+1,a+n+1);
for(int re i=1,cnt,u=a[i].y;i<=n;u=a[++i].y){
cnt=0;
for(int re j=1;j<=n;++j)cnt+=(v[j]=d[j]>d[u]||(d[j]==d[u]&&j>=u));
if(cnt<k)break;
f[u][1]=1;
for(int re i=2;i<=n;++i)f[u][i]=0;
for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]])dfs(v,u);
ans=(ans+d[u]*f[u][k])%mod;
}
cout<<ans<<"\n";
return 0;
}