题目描述:
游游拿到了一棵树,其中每个节点上有一个数字('0'~'9')。
现在游游定义f(i)为:以i号节点为起点时,取一条路径,上面所有数字拼起来是3的倍数的方案数。
现在小红希望你求出f(1)到f(n)的值,你能帮帮她吗?注:前导零也是合法的。
题解:
暴力:从每个根开始暴力,发现会超时
正解:树形DP+换根
注意到题目要求的是拼合为3的倍数,根据3的倍数的性质,我们只需要对每位数求和看是否为3的倍数即可,动态规划状态设定如下:
设tmp[i][k]表示以i为根的路径拼和模3后为k的方案数 k取0到2
例如tmp[1][0]代表以1根的路径拼和模3后为0(3的倍数)的方案数
下面是状态转移公式: x是根,v是儿子
if (a[x] % 3 == 2) {
tmp[x][0] += tmp[v][1];
tmp[x][1] += tmp[v][2];
tmp[x][2] += tmp[v][0];
}
if (a[x] % 3 == 1) {
tmp[x][0] += tmp[v][2];
tmp[x][1] += tmp[v][0];
tmp[x][2] += tmp[v][1];
}
if (a[x] % 3 == 0) {
tmp[x][0] += tmp[v][0];
tmp[x][1] += tmp[v][1];
tmp[x][2] += tmp[v][2];
}
考虑换根
当要把根从x换到v的时候,注意到我们只需要将v对x的贡献从x中剔除,tmp[x][k]满足换根为v后的情况,此时x成为v的儿子,我们只需要再将x的贡献加到v中即可算出换根后的tmp[v][k],然后遍历换根即可。
代码如下:
//
#include <bits/stdc++.h>
using namespace std;
#define maxn 200111
#define ll long long
int n;
ll a[maxn];
ll tmp[maxn][4];
vector<ll> g[maxn];
ll f[maxn];
int mark[maxn];
void dfs(int x, int fa) {
tmp[x][a[x] % 3] = 1; //自己单独为一个路径的情况
for (int v : g[x]) {
if (v == fa)
continue;
dfs(v, x); //遍历树
if (a[x] % 3 == 2) {
tmp[x][0] += tmp[v][1];
tmp[x][1] += tmp[v][2];
tmp[x][2] += tmp[v][0];
}
if (a[x] % 3 == 1) {
tmp[x][0] += tmp[v][2];
tmp[x][1] += tmp[v][0];
tmp[x][2] += tmp[v][1];
}
if (a[x] % 3 == 0) {
tmp[x][0] += tmp[v][0];
tmp[x][1] += tmp[v][1];
tmp[x][2] += tmp[v][2];
}
}
}
//换根DFS
void dfs2(int x) {
for (int v : g[x]) {
if (mark[v])
continue;
if (!mark[v]) {
mark[v] = 1;
//注意次数算贡献一定要将儿子v的tmp值临时存储,因为遍历v之后儿子v的tmp值会变
int r0 = tmp[v][0];
int r1 = tmp[v][1];
int r2 = tmp[v][2];
if (a[x] % 3 == 2) {
tmp[x][0] -= tmp[v][1];
tmp[x][1] -= tmp[v][2];
tmp[x][2] -= tmp[v][0];
}
if (a[x] % 3 == 1) {
tmp[x][0] -= tmp[v][2];
tmp[x][1] -= tmp[v][0];
tmp[x][2] -= tmp[v][1];
}
if (a[x] % 3 == 0) {
tmp[x][0] -= tmp[v][0];
tmp[x][1] -= tmp[v][1];
tmp[x][2] -= tmp[v][2];
}
if (a[v] % 3 == 2) {
tmp[v][0] += tmp[x][1];
tmp[v][1] += tmp[x][2];
tmp[v][2] += tmp[x][0];
}
if (a[v] % 3 == 1) {
tmp[v][0] += tmp[x][2];
tmp[v][1] += tmp[x][0];
tmp[v][2] += tmp[x][1];
}
if (a[v] % 3 == 0) {
tmp[v][0] += tmp[x][0];
tmp[v][1] += tmp[x][1];
tmp[v][2] += tmp[x][2];
}
f[v] = tmp[v][0];
dfs2(v);
//回溯
if (a[x] % 3 == 2) {
tmp[x][0] += r1;
tmp[x][1] += r2;
tmp[x][2] += r0;
}
if (a[x] % 3 == 1) {
tmp[x][0] += r2;
tmp[x][1] += r0;
tmp[x][2] += r1;
}
if (a[x] % 3 == 0) {
tmp[x][0] += r0;
tmp[x][1] += r1;
tmp[x][2] += r2;
}
}
}
}
int main() {
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
//第一次以1为根遍历,然后再考虑换根
dfs(1, 0);
mark[1] = 1;
f[1] = tmp[1][0];
//换根
for (int v : g[1]) {
//注意次数算贡献一定要将儿子v的tmp值临时存储,因为遍历v之后儿子v的tmp值会变
int r0 = tmp[v][0];
int r1 = tmp[v][1];
int r2 = tmp[v][2];
if (!mark[v]) {
mark[v] = 1;
if (a[1] % 3 == 2) {
tmp[1][0] -= tmp[v][1];
tmp[1][1] -= tmp[v][2];
tmp[1][2] -= tmp[v][0];
}
if (a[1] % 3 == 1) {
tmp[1][0] -= tmp[v][2];
tmp[1][1] -= tmp[v][0];
tmp[1][2] -= tmp[v][1];
}
if (a[1] % 3 == 0) {
tmp[1][0] -= tmp[v][0];
tmp[1][1] -= tmp[v][1];
tmp[1][2] -= tmp[v][2];
}
if (a[v] % 3 == 2) {
tmp[v][0] += tmp[1][1];
tmp[v][1] += tmp[1][2];
tmp[v][2] += tmp[1][0];
}
if (a[v] % 3 == 1) {
tmp[v][0] += tmp[1][2];
tmp[v][1] += tmp[1][0];
tmp[v][2] += tmp[1][1];
}
if (a[v] % 3 == 0) {
tmp[v][0] += tmp[1][0];
tmp[v][1] += tmp[1][1];
tmp[v][2] += tmp[1][2];
}
f[v] = tmp[v][0];
dfs2(v);
if (a[1] % 3 == 2) {
tmp[1][0] += r1;
tmp[1][1] += r2;
tmp[1][2] += r0;
}
if (a[1] % 3 == 1) {
tmp[1][0] += r2;
tmp[1][1] += r0;
tmp[1][2] += r1;
}
if (a[1] % 3 == 0) {
tmp[1][0] += r0;
tmp[1][1] += r1;
tmp[1][2] += r2;
}
}
}
//输出结果
for (int i = 1; i <= n; i++) {
cout << f[i] << endl;
}
}
时间复杂度O(n)