Solution
$f_{u,0/1,i}$ 表示在 $u$ 的子树中选了 $i$ 个点、且 $u$ 不选($0$)/ 选($1$)时的答案。
转移显然就是一个卷积的形式。
考虑重链剖分。
先把每个点的轻儿子的答案卷到该点(重链上的点)上,这样只需要考虑重链。
再考虑重链上的DP。
对重链上的一段,记录其头、尾两个点分别选不选(共 $2\times 2$ 种状态)。合并相邻两段也是卷积的形式,可以分治 + FFT(代码中用的是模 $998244353$ 的 NTT)。
#include <bits/stdc++.h>
// Debug helper: prints "expr = value" to stderr.
#define show(x) cerr << #x << " = " << x << endl
using namespace std;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
constexpr int MOD = 998244353;
// Capacity of the static per-vertex arrays (vertices are numbered 1..n, so N must exceed n).
constexpr int N = 202020;
using ll = long long;
// Polynomial stored as its coefficient list, lowest degree first.
using poly = vector<int>;
// Pair of polynomials for one DP state: v0 holds the coefficients for the
// "not chosen" case and v1 for the "chosen" case (index i = number of chosen
// vertices, as in the f_{u,0/1,i} definition at the top of the file).
struct arr {
    poly v0, v1;
    arr(void) {}
    // Arguments are taken by value and moved in, so temporaries are not copied twice.
    arr(poly _v0, poly _v1): v0(std::move(_v0)), v1(std::move(_v1)) {}
    // Length of v0; used as the cost measure when ordering merges.
    inline int size(void) const {
        return v0.size();
    }
    // Deliberately reversed: an arr compares "less" when it is LONGER, so the
    // max-heap std::priority_queue<arr> yields the shortest arr at top().
    inline bool operator <(const arr &b) const {
        return b.size() < size();
    }
};
// 2x2 matrix of polynomials describing a segment of a heavy chain.
// Member v<t><b>: t = whether the segment's top vertex is chosen (0/1),
// b = whether its bottom vertex is chosen (0/1).
struct qua {
    poly v00;
    poly v01;
    poly v10;
    poly v11;
};
// Empty polynomial, used for states that cannot occur (e.g. a single vertex
// that is chosen at one end but not the other).
poly blank{};
// Buffered single-character reader over stdin (refills in 100000-byte chunks).
// Returns EOF (converted to char) once stdin is exhausted.
inline char get(void) {
    static char buf[100000];
    static char *S = buf;
    static char *T = buf;
    if (S != T) return *S++;
    S = buf;
    T = buf + fread(buf, 1, 100000, stdin);
    if (S == T) return EOF;
    return *S++;
}
// Reads the next decimal integer from the buffered input into x.
// Any '-' seen before the first digit makes the result negative.
// If the input ends before a digit is found, x is left as 0; without the EOF
// check the skip loop would spin forever on end of input.
template<typename T>
inline void read(T &x) {
    const char kEnd = static_cast<char>(EOF);  // get() returns EOF converted to char
    x = 0;
    bool negative = false;
    char c = get();
    for (; c < '0' || c > '9'; c = get()) {
        if (c == kEnd) return;
        if (c == '-') negative = true;
    }
    for (; c >= '0' && c <= '9'; c = get()) x = x * 10 + (c - '0');
    if (negative) x = -x;
}
inline int pwr(int a, int b) {
int c = 1;
while (b) {
if (b & 1) c = (ll)c * a % MOD;
b >>= 1; a = (ll)a * a % MOD;
}
return c;
}
// Modular inverse via Fermat's little theorem (MOD is prime): x^(MOD-2).
// x must be nonzero modulo MOD.
inline int inv(int x) {
    const int exponent = MOD - 2;
    return pwr(x, exponent);
}
// (a + b) mod MOD, for residues already in [0, MOD).
inline int sum(int a, int b) {
    const int s = a + b;
    if (s >= MOD) return s - MOD;
    return s;
}
// (a - b) mod MOD, for residues already in [0, MOD).
inline int sub(int a, int b) {
    const int d = a - b;
    if (d < 0) return d + MOD;
    return d;
}
inline void add(int &x, int a) {
x = sum(x, a);
}
namespace FNT {
const int MAXN = 303030;
int ww[MAXN], iw[MAXN];
int rev[MAXN];
int num;
inline void pre(int n) {
num = n;
int g = pwr(3, (MOD - 1) / n);
ww[0] = iw[0] = 1;
for (int i = 1; i < num; i++)
iw[n - i] = ww[i] = (ll)ww[i - 1] * g % MOD;
}
inline void fnt(int *a, int n, int f) {
static int x, y, *w;
w = (f == 1) ? ww : iw;
for (int i = 0; i < n; i++)
if (rev[i] > i)
swap(a[rev[i]], a[i]);
for (int i = 1; i < n; i <<= 1)
for (int j = 0; j < n; j += (i << 1))
for (int k = 0; k < i; k++) {
x = a[j + k];
y = (ll)a[j + k + i] * w[num / (i << 1) * k] % MOD;
a[j + k] = sum(x, y);
a[j + k + i] = sub(x, y);
}
if (f == -1){
int in = inv(n);
for (int i = 0; i < n; i++)
a[i] = (ll)a[i] * in % MOD;
}
}
}
// Product of two polynomials over Z/MOD.
// An empty operand yields an empty result. Small products (|a| * |b| <= 10000)
// use schoolbook multiplication; larger ones use the NTT in FNT, which needs
// FNT::pre to have been called with a length >= the padded transform length.
inline poly operator *(const poly &a, const poly &b) {
    if (a.empty() || b.empty())
        return poly();
    using namespace FNT;
    const int m = static_cast<int>(a.size() + b.size()) - 1;
    poly c(m, 0);
    if ((ll)a.size() * b.size() <= 10000) {
        for (size_t i = 0; i < a.size(); i++)
            for (size_t j = 0; j < b.size(); j++)
                add(c[i + j], (ll)a[i] * b[j] % MOD);
        return c;
    }
    // l = smallest power of two >= m, L = log2(l) - 1 (top bit for the reversal).
    int l = 1, L = 0;
    while (l < m) {
        l <<= 1;
        ++L;
    }
    --L;
    // The twiddle table only supports lengths up to num; larger l would index wrongly.
    assert(l <= num);
    // Scratch buffers sized to the transform. Fixed arrays of N ints are too small
    // once l reaches 2^18.
    vector<int> p(l, 0), q(l, 0);
    copy(a.begin(), a.end(), p.begin());
    copy(b.begin(), b.end(), q.begin());
    for (int i = 0; i < l; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << L);
    fnt(p.data(), l, 1);
    fnt(q.data(), l, 1);
    for (int i = 0; i < l; i++)
        p[i] = (ll)p[i] * q[i] % MOD;
    fnt(p.data(), l, -1);
    copy(p.begin(), p.begin() + m, c.begin());
    return c;
}
// Coefficient-wise sum of two polynomials mod MOD; the result has the length
// of the longer operand. Operands are taken by const reference and the result
// is a fresh local, so a call makes no copies of its inputs.
inline poly operator +(const poly &a, const poly &b) {
    poly c(max(a.size(), b.size()), 0);
    for (size_t i = 0; i < a.size(); i++)
        c[i] = sum(c[i], a[i]);
    for (size_t i = 0; i < b.size(); i++)
        c[i] = sum(c[i], b[i]);
    return c;
}
// Concatenates two chain segments, `a` directly above `b`. Member v<t><b> means
// (top chosen?, bottom chosen?). a's bottom vertex is adjacent to b's top
// vertex, so combinations where both are chosen (a.v?1 with b.v1?) are excluded.
// Factored as a.vx0 * (b.v0y + b.v1y) + a.vx1 * b.v0y: 8 polynomial products
// instead of 12, with identical coefficients and result lengths.
inline qua operator *(const qua &a, const qua &b) {
    const poly bottom0 = b.v00 + b.v10;  // b's top either way, b's bottom not chosen
    const poly bottom1 = b.v01 + b.v11;  // b's top either way, b's bottom chosen
    return qua { a.v00 * bottom0 + a.v01 * b.v00,
                 a.v00 * bottom1 + a.v01 * b.v01,
                 a.v10 * bottom0 + a.v11 * b.v00,
                 a.v10 * bottom1 + a.v11 * b.v01 };
}
// Multiplies the two cases independently: (a.v0 * b.v0, a.v1 * b.v1).
inline arr operator *(arr a, arr b) {
    poly notChosen = a.v0 * b.v0;
    poly chosen = a.v1 * b.v1;
    return arr(notChosen, chosen);
}
// Adjacency lists of the input tree.
vector<int> G[N];
// Records the undirected edge {from, to} in both adjacency lists.
inline void addEdge(int from, int to) {
    G[from].emplace_back(to);
    G[to].emplace_back(from);
}
int n, m, clc;
int fa[N], son[N], size[N], pre[N], erp[N];
int w[N];
arr f[N];
inline void dfs1(int u) {
size[u] = 1;
pre[u] = ++clc;
erp[clc] = u;
for (int to: G[u]) {
if (to == fa[u]) continue;
fa[to] = u; dfs1(to);
size[u] += size[to];
if (size[to] > size[son[u]])
son[u] = to;
}
}
// Work queue of arr's to be multiplied together; top() is the shortest entry.
priority_queue<arr> Q;
// Multiplies together every arr in Q and returns the product, leaving Q empty.
// Always combines the two shortest entries first (as in Huffman coding) to
// keep convolution sizes balanced. Returns an empty arr if Q is empty; a
// function-static result would return stale data from the previous call here.
inline arr merge(void) {
    if (Q.empty()) return arr();
    while (Q.size() > 1) {
        arr a = Q.top(); Q.pop();
        arr b = Q.top(); Q.pop();
        Q.push(a * b);
    }
    arr result = Q.top();
    Q.pop();
    return result;
}
// The constant polynomial 1: no vertex chosen, weight 1.
inline poly f0(void) {
    return poly(1, 1);
}
// The polynomial w * x: exactly one vertex chosen, contributing weight w.
// w is reduced into [0, MOD) first, because sum()/sub() and the NTT rely on
// canonical residues and a raw weight >= MOD (or negative) would break them.
inline poly f1(int w) {
    const int r = (int)(((ll)w % MOD + MOD) % MOD);
    return poly{0, r};
}
// Per-vertex matrices of the heavy chain being processed, ordered top to bottom.
vector<qua> lt;
// Returns the ordered product lt[l] * lt[l+1] * ... * lt[r] (requires l <= r),
// splitting the range in halves so the multiplied polynomials stay balanced.
inline qua divAndConq(int l, int r) {
    if (l != r) {
        const int mid = l + (r - l) / 2;
        qua upper = divAndConq(l, mid);
        qua lower = divAndConq(mid + 1, r);
        return upper * lower;
    }
    return lt[l];
}
// Debug: prints the coefficients of a polynomial to stderr as "{ c0, c1, ..., }".
inline void watch(poly x) {
    cerr << "{ ";
    for (size_t i = 0; i < x.size(); ++i)
        cerr << x[i] << ", ";
    cerr << "}" << endl;
}
// Debug: prints both polynomials of an arr, v0 first.
inline void watch(arr x){
    watch(x.v0);
    watch(x.v1);
}
inline void fuck(int v) {
lt.clear();
for (int u = v; u; u = son[u]) {
for (int to: G[u])
if (to != son[u] && to != fa[u])
Q.push(arr(f[to].v0 + f[to].v1, f[to].v0));
arr cur;
if (!Q.empty()) cur = merge();
//watch(cur);
if (cur.v0.empty()) cur.v0 = f0();
if (cur.v1.empty()) cur.v1 = f1(w[u]);
else cur.v1 = cur.v1 * f1(w[u]);
//watch(cur);
lt.push_back(qua{cur.v0, blank, blank, cur.v1});
}
qua res = divAndConq(0, lt.size() - 1);
f[v] = arr(res.v00 + res.v01, res.v10 + res.v11);
//watch(f[v]);
}
// Reads n, m, the n vertex weights and n-1 edges, then prints the coefficient
// of x^m of the root's DP (root not chosen + root chosen), mod MOD.
int main(void) {
    // NOTE(review): hard-wired file redirection looks like a local-testing
    // leftover; remove these two lines if input/output go through stdin/stdout.
    freopen("1.in", "r", stdin);
    freopen("1.out", "w", stdout);
    FNT::pre(1 << 18);  // twiddle tables for NTT lengths up to 2^18
    read(n); read(m);
    for (int i = 1; i <= n; i++) read(w[i]);
    for (int i = 1; i < n; i++) {
        int x, y;
        read(x); read(y);
        addEdge(x, y);
    }
    dfs1(1);
    // Reverse preorder: every vertex of a chain-top's subtree (including all
    // light children) is handled before the chain-top itself.
    for (int i = n; i >= 1; i--) {
        const int u = erp[i];
        if (son[fa[u]] != u) fuck(u);
    }
    int ans = 0;
    if (m >= 0) {
        const size_t k = static_cast<size_t>(m);
        // Index k exists only when size() > k; `>= m` would read one past the end.
        if (f[1].v0.size() > k) ans = sum(ans, f[1].v0[k]);
        if (f[1].v1.size() > k) ans = sum(ans, f[1].v1[k]);
    }
    cout << ans << endl;
    return 0;
}