传送门
思路:考虑对每个颜色分别进行计算。如果该节点的颜色为当前计算的颜色,令
v
a
l
i
=
1
val_i = 1
vali=1,否则
v
a
l
i
=
−
1
val_i = -1
vali=−1。显然要使子树颜色相同的数量严格大于一半等价于这个子树和大于等于1。
求一个子树中和大于等于1的方案数,我们可以树上背包。
定义: d p 1 [ u ] [ s ] dp_1[u][s] dp1[u][s]: 以u为根的子树,和为 s ( s ≥ 1 ) s(s\geq 1) s(s≥1)的方案数。 d p 2 [ u ] [ s ] dp_2[u][s] dp2[u][s]:和为 − s ( s ≥ 1 ) -s(s \geq 1) −s(s≥1)的方案数。 d p 3 [ u ] dp_3[u] dp3[u]和为 0 0 0的方案数
状态转移:
-
i + j < = m i+j <= m i+j<=m: d p 1 [ u ] [ i + j ] = d p 1 [ u ] [ i ] ∗ d p 1 [ v ] [ j ] dp_1[u][i+j] = dp_1[u][i] * dp_1[v][j] dp1[u][i+j]=dp1[u][i]∗dp1[v][j]
-
i > j i > j i>j : d p 1 [ u ] [ i − j ] = d p 1 [ u ] [ i ] ∗ d p 2 [ v ] [ j ] dp_1[u][i-j] = dp_1[u][i] * dp_2[v][j] dp1[u][i−j]=dp1[u][i]∗dp2[v][j]
-
i < j i < j i<j: d p 1 [ u ] [ j − i ] = d p 2 [ u ] [ i ] ∗ d p 1 [ v ] [ j ] dp_1[u][j-i] = dp_2[u][i] * dp_1[v][j] dp1[u][j−i]=dp2[u][i]∗dp1[v][j]
r e s = ∑ u ∑ i = 1 m d p 1 [ u ] [ i ] res = \sum _u\sum_{i=1}^{m} dp_1[u][i] res=∑u∑i=1mdp1[u][i]
d p 2 dp_2 dp2的转移类似, d p 3 dp_3 dp3的转移也十分好推。
每次计算的复杂度: O ( n ∗ m ) O(n*m) O(n∗m),总的复杂度: O ( n ∗ ∑ m ) O(n*\sum m) O(n∗∑m) = O ( n 2 ) O(n^2) O(n2), m m m为每个颜色的数量。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 3000 + 10;
const int mod = 998244353;
vector<int> g[MAXN];
int n, m;
int val[MAXN], c[MAXN], vis[MAXN];
ll dp1[MAXN][MAXN], dp2[MAXN][MAXN], dp3[MAXN];
ll tmp1[MAXN][MAXN], tmp2[MAXN][MAXN], tmp3[MAXN];
ll res;
int dfs(int u, int fa) {
int p = 1;
if(val[u] == 1) dp1[u][1] = 1;
else dp2[u][1] = 1;
for(auto v : g[u]) {
if(v == fa) continue;
int sz = dfs(v, u);
for(int i = 0; i <= min(p, m); i++) {
tmp1[u][i] = dp1[u][i];
tmp2[u][i] = dp2[u][i];
tmp3[u] = dp3[u];
}
dp3[u] = (dp3[u] + tmp3[u] * dp3[v]) % mod;
for(int j = 1; j <= sz && j <= m; j++) {
dp1[u][j] = (dp1[u][j] + tmp3[u] * dp1[v][j]) % mod;
dp2[u][j] = (dp2[u][j] + tmp3[u] * dp2[v][j]) % mod;
}
for(int i = 1; i <= min(p, m); i++) {
dp1[u][i] = (dp1[u][i] + tmp1[u][i] * dp3[v]) % mod;
dp2[u][i] = (dp2[u][i] + tmp2[u][i] * dp3[v]) % mod;
for(int j = 1; j <= sz && j <= m; j++) {
if(i+j <= m) {
dp1[u][i+j] = (dp1[u][i+j] + tmp1[u][i] * dp1[v][j]) % mod;
dp2[u][i+j] = (dp2[u][i+j] + tmp2[u][i] * dp2[v][j]) % mod;
}
if(i-j >= 1) {
dp1[u][i-j] = (dp1[u][i-j] + tmp1[u][i] * dp2[v][j]) % mod;
dp2[u][i-j] = (dp2[u][i-j] + tmp2[u][i] * dp1[v][j]) % mod;
}
if(j-i >= 1) {
dp1[u][j-i] = (dp1[u][j-i] + tmp2[u][i] * dp1[v][j]) % mod;
dp2[u][j-i] = (dp2[u][j-i] + tmp1[u][i] * dp2[v][j]) % mod;
}
if(i == j) {
dp3[u] = (dp3[u] + tmp1[u][i] * dp2[v][j] + tmp2[u][i] * dp1[v][j]) % mod;
}
}
}
p += sz;
}
for(int i = 1; i <= min(p, m); i++) {
res = (res + dp1[u][i]) % mod;
}
return p;
}
void solve() {
cin >> n;
for(int i = 1; i <= n; i++) {
cin >> c[i];
}
for(int i = 1; i < n; i++) {
int u, v; cin >> v >> u;
g[u].push_back(v);
g[v].push_back(u);
}
for(int i = 1; i <= n; i++) {
if(vis[c[i]]) continue;
vis[c[i]] = 1; m = 0;
for(int j = 1; j <= n; j++) {
val[j] = (c[j] == c[i] ? 1 : -1);
if(c[j] == c[i]) m++;
}
for(int j = 1; j <= n; j++) {
for(int k = 0; k <= m; k++) {
dp1[j][k] = dp2[j][k] = dp3[j] = 0;
}
}
dfs(1, 0);
}
cout << res << "\n";
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);
solve();
return 0;
}