题目描述
小 F 的生日还有一个多月,大 F 早早地准备起了礼物。
“你想要什么礼物呀?嗯…要不要好吃的?”
“才不要呢,我想要好看的花,永远不会凋谢的花。”
小 F 和大 F 一起生活的国家—— Fairy 国,可以抽象成一棵 N N N 个节点的树,每个节点就是一个城市,编号为 1 … N 1\ldots N 1…N。
大 F 要游历各个城市,为心爱的小 F 寻找好看的花。
Fairy 国的每个城市都有一座山,山上有恰好一朵永远不会凋谢的花,编号为 i i i 的城市的花的美丽值为 B i B_i Bi。大 F 要在 N N N 个城市中选出恰好 M M M 个,并摘来这 M M M 个城市中的 M M M 朵花送给小F。可是呢,如果树上的一条边连接的两个城市的花都被摘去,这条边就会塌陷,Fairy 国就会陷入分裂,大 F 作为一个善良的人,不希望这样的情况发生。所以,一种摘法合法,当且仅当对于每条边,这条边相连的两个节点的花不被同时摘去。
大 F 希望小 F 快乐,小 F 的快乐程度将是摘来的
M
M
M 朵花的美丽程度的积。大 F 今天闲着没事,想要求出对于所有合法的摘法,小 F 的快乐程度之和对
998244353
998244353
998244353 取模的结果。
数据范围
对于所有数据,保证 1 ≤ M ≤ N ≤ 8 × 1 0 4 1 \le M \le N \le 8 \times 10^4 1≤M≤N≤8×104, 0 ≤ B i < 998244353 0 \le B_i < 998244353 0≤Bi<998244353。
下表为各个 Subtask 的额外限制与得分,空格表示该项无额外限制。你只有通过一个 Subtask 的所有数据才能得到该 Subtask 的分。
Subtask 编号 | N N N | M M M | 特殊限制 | 分值 |
---|---|---|---|---|
1 | ≤ 500 \le 500 ≤500 | 7 | ||
2 | ≤ 4000 \le 4000 ≤4000 | 15 | ||
3 | ≤ 10 \le 10 ≤10 | 15 | ||
4 | ∀ 1 ≤ i < N \forall 1\le i < N ∀1≤i<N,读入的第 i i i 条边是 ( i , i + 1 ) (i,i+1) (i,i+1) | 18 | ||
5 | ∀ 1 ≤ i < N \forall 1\le i < N ∀1≤i<N,读入的第 i i i 条边是 ( 1 , i + 1 ) (1,i+1) (1,i+1) | 20 | ||
6 | 25 |
题解
考虑暴力dp,设 f i , 0 / 1 , j f_{i,0/1,j} fi,0/1,j 表示 i i i 子树内取了 j j j 个点, i i i 取/不取的价值之和
如果是一条链的话,可以分治+ N t t Ntt Ntt实现
设 F l , r , 0 / 1 , 0 / 1 ( x ) F_{l,r,0/1,0/1}(x) Fl,r,0/1,0/1(x) 表示 [ l , r ] [l,r] [l,r] 区间, l l l 取/不取, r r r 取/不取,摘了 i i i 朵的答案在 x i x^i xi 的系数上
考虑一棵树,将其树链剖分,一条重链上的每个点要先将其轻儿子的信息也用分治+ N t t Ntt Ntt 的方法存到每个点上,再将这条重链用上述方法操作即可
效率不会证,貌似 O ( n l o g 2 n / n l o g 3 n ) O(nlog^2n\ /\ nlog^3n) O(nlog2n / nlog3n)
code
#include <bits/stdc++.h>
#define E vector<int>
#define mid ((l+r)>>1)
using namespace std;
const int N=2e5+5,P=998244353;
int n,m,w[N],sz[N],son[N],fa[N],t,tt,hd[N],V[N*2];
int nx[N*2],T[N],b[N],B,re[N*8],S[2]={3,(P+1)/3},p,G[N*8],H[N*8];
E f[2][N],g[2][N],I,M,W;struct O{E a[2][2];}tmp,rs;
int K(int x,int y){
int z=1;
for (;y;y>>=1,x=1ll*x*x%P)
if (y&1) z=1ll*z*x%P;
return z;
}
void add(int u,int v){
nx[++tt]=hd[u];V[hd[u]=tt]=v;
}
void pre(int l){
for (t=1,p=0;t<l;t<<=1,p++);
for (int i=0;i<t;i++)
re[i]=(re[i>>1]>>1)|((i&1)<<(p-1));
}
void Ntt(int *s,bool o){
for (int i=0;i<t;i++)
if (i<re[i]) swap(s[i],s[re[i]]);
for (int wn,i=1;i<t;i<<=1){
wn=K(S[o],(P-1)/(i<<1));
for (int x,y,j=0;j<t;j+=(i<<1))
for (int w=1,k=0;k<i;k++,w=1ll*w*wn%P)
x=s[j+k],y=1ll*w*s[i+j+k]%P,
s[j+k]=(x+y)%P,s[i+j+k]=(x-y+P)%P;
}
if (o)
for (int i=0,v=K(t,P-2);i<t;i++)
s[i]=1ll*v*s[i]%P;
}
E by(E a,E b){
int la=a.size(),lb=b.size();pre(la+lb);
for (int i=0;i<la;i++) G[i]=a[i];
for (int i=0;i<lb;i++) H[i]=b[i];
Ntt(G,0);Ntt(H,0);
for (int i=0;i<t;i++) G[i]=1ll*G[i]*H[i]%P;
Ntt(G,1);W.clear();
for (int i=0;i<la+lb-1;i++) W.push_back(G[i]);
for (int i=0;i<t;i++) G[i]=H[i]=0;
return W;
}
E ad(E a,E b){
int la=a.size(),lb=b.size(),lc=max(la,lb);
for (int i=0;i<lc;i++)
G[i]=((i<la?a[i]:0)+(i<lb?b[i]:0))%P;
W.clear();
for (int i=0;i<lc;i++) W.push_back(G[i]),G[i]=0;
return W;
}
E div(int l,int r,int o){
if (l==r) return f[o][b[l]];
return by(div(l,mid,o),div(mid+1,r,o));
}
O solve(int l,int r){
if (l==r){
for (int i=0;i<2;i++)
tmp.a[i][i]=g[i][T[l]],
tmp.a[i][!i]={0};
return tmp;
}
O L=solve(l,mid),R=solve(mid+1,r);
for (int i=0;i<2;i++)
for (int j=0;j<2;j++)
tmp.a[i][j]=ad(by(L.a[i][0],ad(R.a[0][j],R.a[1][j])),by(L.a[i][1],R.a[0][j]));
return tmp;
}
void work(int x){
tt=0;int u=x;
for (;x;x=son[x]){
T[++tt]=x;B=0;M[1]=w[x];
for (int i=hd[x];i;i=nx[i])
if (V[i]!=fa[x] && V[i]!=son[x])
b[++B]=V[i];
g[1][x]=by(B?div(1,B,0):I,M);
g[0][x]=B?div(1,B,1):I;
}
rs=solve(1,tt);
f[0][u]=ad(rs.a[0][0],rs.a[0][1]);
f[1][u]=ad(f[0][u],ad(rs.a[1][0],rs.a[1][1]));
}
void dfs(int u,int fr){
sz[u]=1;fa[u]=fr;
for (int i=hd[u];i;i=nx[i])
if (V[i]!=fr){
dfs(V[i],u),sz[u]+=sz[V[i]];
if (sz[V[i]]>sz[son[u]]) son[u]=V[i];
}
for (int i=hd[u];i;i=nx[i])
if (V[i]!=fr && V[i]!=son[u]) work(V[i]);
}
int main(){
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++)
scanf("%d",&w[i]);
for (int x,y,i=1;i<n;i++)
scanf("%d%d",&x,&y),
add(x,y),add(y,x);
I.push_back(1);M.push_back(0);
M.push_back(0);dfs(1,0);work(1);
if (m<(int)f[1][1].size()) printf("%d\n",f[1][1][m]);
else puts("0");return 0;
}