2021 ICPC Southeastern Europe Regional Contest Werewolves(树上背包)
链接
题意:给出一个n个节点的树(
n
≤
3000
n\le3000
n≤3000),每个点有自己的颜色,好子树的定义是,子树内有一半以上的节点是同一种颜色,问有多少种划分子树的方法,最后对
998244353
998244353
998244353取模。
思路:dls讲的树上背包。。当时听了也没太明白代码怎么写,现在想想还是对树上背包这种不太熟。我们对于每个颜色,都要在树上进行一次dp,这样就可以,对于颜色
i
i
i,相同看作1,不同是-1,最后统计所有的和大于0的组合方法,所以我们定义数组
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j]代表以
i
i
i为根,总和为
j
j
j的方案数,正常开成2倍表示负数就行,但是懒得了,就开成
d
p
[
i
]
[
j
]
[
0
/
1
]
dp[i][j][0/1]
dp[i][j][0/1],0代表负数,1代表正数,最后单独开一个数组代表刚好为0。现在先看转移
①对于
u
u
u的每一个子树
v
v
v(用
s
u
m
i
sum_i
sumi代表
i
i
i号点的和),如果
s
u
m
u
+
s
u
m
v
≤
c
n
t
这
个
颜
色
的
总
数
sum_u+sum_v\le cnt_{这个颜色的总数}
sumu+sumv≤cnt这个颜色的总数那么首先有:
d
p
[
u
]
[
j
+
k
]
[
1
]
=
d
p
[
u
]
[
j
+
k
]
[
1
]
+
o
r
i
[
u
]
[
j
]
[
1
]
∗
d
p
[
v
]
[
k
]
[
1
]
d
p
[
u
]
[
j
+
k
]
[
0
]
=
d
p
[
u
]
[
j
+
k
]
[
0
]
+
o
r
i
[
u
]
[
j
]
[
0
]
∗
d
p
[
v
]
[
k
]
[
0
]
\begin{array}{c} dp[u][j + k][1] = dp[u][j + k][1] + ori[u][j][1] * dp[v][k][1] \\ dp[u][j + k][0] = dp[u][j + k][0] + ori[u][j][0] * dp[v][k][0] \end{array}
dp[u][j+k][1]=dp[u][j+k][1]+ori[u][j][1]∗dp[v][k][1]dp[u][j+k][0]=dp[u][j+k][0]+ori[u][j][0]∗dp[v][k][0],其中
o
r
i
ori
ori是这个状态在进行这次转移之前初始状态,正确性是因为,对于一个总和,那他肯定是
u
u
u的总和为
j
j
j的情况和
v
v
v的总和为
k
k
k的情况组合起来。
②对于一个子树
v
v
v,如果
s
u
m
u
≥
s
u
m
v
sum_u \ge sum_v
sumu≥sumv那么就是,对于一个情况,可以有
s
u
m
u
−
s
u
m
v
sum_u-sum_v
sumu−sumv,转移方程就为
d
p
[
u
]
[
j
−
k
]
[
0
]
=
d
p
[
u
]
[
j
−
k
]
[
0
]
+
o
r
i
[
u
]
[
j
]
[
0
]
∗
d
p
[
v
]
[
k
]
[
1
]
d
p
[
u
]
[
j
−
k
]
[
1
]
=
d
p
[
u
]
[
j
−
k
]
[
1
]
+
o
r
i
[
u
]
[
j
]
[
1
]
∗
d
p
[
v
]
[
k
]
[
0
]
\begin{array}{c} dp[u][j-k][0] = dp[u][j-k][0] + ori[u][j][0]*dp[v][k][1] \\ dp[u][j-k][1] = dp[u][j-k][1] + ori[u][j][1]*dp[v][k][0] \end{array}
dp[u][j−k][0]=dp[u][j−k][0]+ori[u][j][0]∗dp[v][k][1]dp[u][j−k][1]=dp[u][j−k][1]+ori[u][j][1]∗dp[v][k][0]。
③ 跟②情况刚好相反
④ 两个相减刚好为0,就有转移方程
d
p
0
[
u
]
=
d
[
u
]
+
o
r
i
[
u
]
[
j
]
[
0
]
∗
d
p
[
v
]
[
k
]
[
1
]
+
o
r
i
[
u
]
[
j
]
[
1
]
∗
d
p
[
v
]
[
k
]
[
0
]
dp0[u] = d[u] + ori[u][j][0] * dp[v][k][1] + ori[u][j][1] * dp[v][k][0]
dp0[u]=d[u]+ori[u][j][0]∗dp[v][k][1]+ori[u][j][1]∗dp[v][k][0]
⑤对于
u
u
u来说,每次转移最开始,所有的次数都可以从
u
u
u的和为0的状态转移
d
p
[
u
]
[
j
]
[
0
]
=
d
p
[
u
]
[
j
]
[
0
]
+
o
r
i
0
[
u
]
∗
d
p
[
v
]
[
j
]
[
0
]
d
p
[
u
]
[
j
]
[
1
]
=
d
p
[
u
]
[
j
]
[
1
]
+
o
r
i
0
[
u
]
∗
d
p
[
v
]
[
j
]
[
1
]
d
p
0
[
u
]
=
d
p
0
[
u
]
+
o
r
i
0
[
u
]
∗
d
p
0
[
v
]
\begin{array}{c} dp[u][j][0] = dp[u][j][0] + ori0[u] * dp[v][j][0]\\ dp[u][j][1] = dp[u][j][1] + ori0[u] * dp[v][j][1]\\ dp0[u] = dp0[u] + ori0[u] * dp0[v]\end{array}
dp[u][j][0]=dp[u][j][0]+ori0[u]∗dp[v][j][0]dp[u][j][1]=dp[u][j][1]+ori0[u]∗dp[v][j][1]dp0[u]=dp0[u]+ori0[u]∗dp0[v]
所有的情况就讨论完了,但是这样写上去很明显是一个
n
3
n^3
n3的算法,所以加上优化,对于每个颜色,记录他的
c
n
t
cnt
cnt,如果对于dfs内部的循环,如果他们枚举的和,最大就是m,并且不会超过他节点的总数size,所以每个循环应该小于
m
i
n
(
m
,
s
i
z
e
)
min(m, size)
min(m,size)。
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int MOD = 998244353;
const int N = 3e3+10;
int dp[N][N][2], tmp[N][N][2], d[N], tp[N], ans, val[N], c[N], vis[N], n, m;
int head[N], idx;
struct Edge{int to, nxt;}e[N << 1];
void add(int u, int v) {e[++idx].to = v, e[idx].nxt = head[u], head[u] = idx;}
int dfs(int u, int fa)
{
int p = 1;
if (val[u]) dp[u][1][1] = 1;
else dp[u][1][0] = 1;
for (int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if (v == fa) continue;
int siz = dfs(v, u);
tp[u] = d[u];
for (int j = 1; j <= min(p, m); j++) {
tmp[u][j][0] = dp[u][j][0];
tmp[u][j][1] = dp[u][j][1];
}
d[u] = (d[u] + d[v] * tp[u]) % MOD;
for (int j = 1; j <= min(siz, m); j++) {
dp[u][j][0] = (dp[u][j][0] + tp[u] * dp[v][j][0]) % MOD;
dp[u][j][1] = (dp[u][j][1] + tp[u] * dp[v][j][1]) % MOD;
}
for (int j = 1; j <= min(p, m); j++) {
dp[u][j][1] = (dp[u][j][1] + tmp[u][j][1] * d[v]) % MOD;
dp[u][j][0] = (dp[u][j][0] + tmp[u][j][0] * d[v]) % MOD;
for (int k = 1; k <= min(m, siz); k++) {
if (k + j <= m) {
dp[u][k + j][1] = (dp[u][k + j][1] + tmp[u][j][1] * dp[v][k][1]) % MOD;
dp[u][k + j][0] = (dp[u][k + j][0] + tmp[u][j][0] * dp[v][k][0]) % MOD;
}
if (j - k >= 1) {
dp[u][j - k][1] = (dp[u][j - k][1] + tmp[u][j][1] * dp[v][k][0]) % MOD;
dp[u][j - k][0] = (dp[u][j - k][0] + tmp[u][j][0] * dp[v][k][1]) % MOD;
}
if (k - j >= 1) {
dp[u][k - j][1] = (dp[u][k - j][1] + tmp[u][j][0] * dp[v][k][1]) % MOD;
dp[u][k - j][0] = (dp[u][k - j][0] + tmp[u][j][1] * dp[v][k][0]) % MOD;
}
if (j == k) {
d[u] = (d[u] + tmp[u][j][0] * dp[v][k][1] + tmp[u][j][1] * dp[v][k][0]) % MOD;
}
}
}
p += siz;
}
for (int i = 1; i <= min(p, m); i++) {
ans = (ans + dp[u][i][1]) % MOD;
}
return p;
}
signed main()
{
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> c[i];
}
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
add(u, v); add(v, u);
}
for (int i = 1; i <= n; i++) {
if (vis[c[i]]) continue;
m = 0; vis[c[i]] = 1;
for (int j = 1; j <= n; j++) {
if (c[i] == c[j]) {
m++; val[j] = 1;
}else val[j] = 0;
}
for (int j = 1; j <= n; j++) {
d[j] = 0;
for (int k = 1; k <= m; k++) {
dp[j][k][0] = dp[j][k][1] = 0;
}
}
dfs(1, 0);
}
cout << ans;
return 0;
}