题目链接
本题树形dp,分类讨论分析状态转移。
设 :
d
p
[
u
]
[
t
a
g
]
dp[u][tag]
dp[u][tag]表示以
u
u
u为根节点,当
u
u
u为
t
a
g
tag
tag状态时,子树的总
g
o
o
d
good
good点数量;(
t
a
g
=
=
1
tag == 1
tag==1为
g
o
o
d
good
good)
d
p
_
[
u
]
[
t
a
g
]
dp\_[u][tag]
dp_[u][tag]表示以
u
u
u为根节点,当
u
u
u为
t
a
g
tag
tag状态时,子树的总重;
v
a
l
[
u
]
[
t
a
g
]
val[u][tag]
val[u][tag]表示节点
u
u
u为
t
a
g
tag
tag状态时的权值;
d
e
g
[
u
]
deg[u]
deg[u]表示节点的度数。
第一遍
d
f
s
dfs
dfs,状态转移:
1.
n
=
=
2
n==2
n==2时:即两个节点都为
g
o
o
d
.
good.
good.
2.
n
!
=
2
n!=2
n!=2时:
核心思想是:当一个节点
g
o
o
d
good
good时(
n
!
=
2
n!=2
n!=2),此节点的父节点和子节点都不是
g
o
o
d
good
good。
第二遍
d
f
s
dfs
dfs,求每个节点的权值:
#include<iostream>
#include<cstring>
using namespace std;
typedef long long ll;
struct edge
{
int next;
int to;
};
const int N = 2e5 + 10;
int n, cnt;
int head[N];
edge table[N * 2];
int deg[N];
ll val[N][2];
int dp[N][2];
int dp_[N][2];
int ans[N];
inline void add(int& u, int& v)
{
table[++cnt].next = head[u];
head[u] = cnt;
table[cnt].to = v;
}
void dfs(int u, int fa)
{
int e, v;
val[u][1] = deg[u], val[u][0] = 1;
dp[u][1] = 1, dp[u][0] = 0;
dp_[u][1] = val[u][1], dp_[u][0] = val[u][0];
for (e = head[u]; e ; e = table[e].next)
{
v = table[e].to;
if (v == fa) continue;
dfs(v, u);
dp[u][1] += dp[v][0];
dp[u][0] += max(dp[v][0], dp[v][1]);
dp_[u][1] += dp_[v][0];
if (dp[v][0] > dp[v][1])
dp_[u][0] += dp_[v][0];
else if (dp[v][0] < dp[v][1])
dp_[u][0] += dp_[v][1];
else
dp_[u][0] += min(dp_[v][0], dp_[v][1]);
}
}
void get_ans(int u,int fa,bool tag)
{
int e, v;
ans[u] = val[u][tag];
for (e = head[u]; e; e = table[e].next)
{
v = table[e].to;
if (v == fa) continue;
if (tag)
get_ans(v, u, 0);
else
{
if (dp[v][0] > dp[v][1])
get_ans(v, u, 0);
else if (dp[v][0] < dp[v][1])
get_ans(v, u, 1);
else
{
if (dp_[v][0] < dp_[v][1])
get_ans(v, u, 0);
else
get_ans(v, u, 1);
}
}
}
}
int main()
{
int i, u, v;
int nums;
ll res;
scanf("%d", &n);
for (i = 1; i < n; ++i)
{
scanf("%d%d", &u, &v);
add(u, v), add(v, u);
++deg[u], ++deg[v];
}
if (n == 2) //特判
{
printf("%d %d\n%d %d", 2, 2, 1, 1);
return 0;
}
dfs(1, 1);
if (dp[1][1] > dp[1][0])
nums = dp[1][1], res = dp_[1][1], get_ans(1,1,1);
else if (dp[1][1] < dp[1][0])
nums = dp[1][0], res = dp_[1][0], get_ans(1,1,0);
else
{
nums = dp[1][0], res = min(dp_[1][0], dp_[1][1]);
if (dp_[1][0] < dp_[1][1])
get_ans(1, 1, 0);
else
get_ans(1, 1, 1);
}
printf("%d %lld\n", nums, res);
for (i = 1; i <= n; ++i)
printf("%d ", ans[i]);
return 0;
}