Orgrimmar 树形DP
链接:Orgrimmar
题意:
给定一棵树,选中一些点放进集合里,放进集合里的每个点最多只有一条边,最多能有多少点放进集合里。
分析:
对于一条边上的两个点,只会存在三种状态:都不选、选一个和都选。
状态表示:
f
[
i
]
[
0
/
1
/
2
]
f[i][0/1/2]
f[i][0/1/2],对于节点
i
i
i ,来说,有:
- f [ i ] [ 0 ] f[i][0] f[i][0]:不选节点 i i i。
- f [ i ] [ 1 ] f[i][1] f[i][1]:选节点 i i i且一个子节点也不选.
- f [ i ] [ 2 ] f[i][2] f[i][2]:选节点 i i i且只选一个子节点。
状态转移:
i
i
i的子树为
j
j
j,有:
f
[
i
]
[
0
]
f[i][0]
f[i][0]:不选节点
i
i
i的话,那么对于
i
i
i所有的子树来说,只取每个子树中三个状态的最大值求和。
f
[
i
]
[
0
]
+
=
m
a
x
(
f
[
j
]
[
0
]
,
m
a
x
(
f
[
j
]
[
1
]
,
f
[
j
]
[
2
]
)
)
f[i][0]+=max(f[j][0],max(f[j][1],f[j][2]))
f[i][0]+=max(f[j][0],max(f[j][1],f[j][2]))
f
[
i
]
[
1
]
f[i][1]
f[i][1]:选节点i且不选子树的话,那么只能通过子节点不选的状态
f
[
j
]
[
0
]
f[j][0]
f[j][0]来转移。
f
[
i
]
[
1
]
+
=
f
[
j
]
[
0
]
f[i][1]+=f[j][0]
f[i][1]+=f[j][0]
f
[
i
]
[
2
]
f[i][2]
f[i][2]:选节点
i
i
i 且选一个子节点,那么只需要找出来选哪个子节点后其余的子节点都不能选的和的最大值,来转移。因为这里用到了子节点不选这个状态的和,因此需要先处理出来不选子节点的和,再转移。
f
[
i
]
[
2
]
=
m
a
x
(
f
[
i
]
[
2
]
,
s
u
m
−
f
[
j
]
[
0
]
+
f
[
j
]
[
1
]
+
1
)
f[i][2]=max(f[i][2],sum-f[j][0]+f[j][1]+1)
f[i][2]=max(f[i][2],sum−f[j][0]+f[j][1]+1)
初始化 f [ i ] [ 0 ] = 0 , f [ i ] [ 1 ] = 1 f[i][0]=0,f[i][1]=1 f[i][0]=0,f[i][1]=1,有的节点没有子节点,那么 f [ i ] [ 2 ] = − I N F f[i][2]=-INF f[i][2]=−INF
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N=1e6+10,M=2*N,INF=1e8;
int e[M],ne[M],h[N],idx;
int n;
int f[N][3];
int res;
void add(int a,int b)
{
e[idx]=b;
ne[idx]=h[a];
h[a]=idx++;
}
void dfs(int u,int fa)
{
int sum=0;
for(int i=h[u];i!=-1;i=ne[i])
{
int j=e[i];
if(j==fa) continue;
dfs(j,u);
f[u][0]+=max(f[j][0],max(f[j][1],f[j][2]));
f[u][1]+=f[j][0];
sum+=f[j][0];
//f[u][2]=max(f[j][1]+1,f[u][2]);
}
for(int i=h[u];i!=-1;i=ne[i])
{
int j=e[i];
if(j==fa) continue;
f[u][2]=max(f[u][2],sum-f[j][0]+f[j][1]+1);
}
res=max(max(res,f[u][0]),max(f[u][1],f[u][2]));
}
void solve()
{
res=idx=0;
cin>>n;
for(int i=1;i<=n;i++)
{
h[i]=-1;
f[i][1]=1;
f[i][2]=-INF;
f[i][0]=0;
}
for(int i=1;i<n;i++)
{
int a,b;
scanf("%d %d",&a,&b);
add(a,b);
add(b,a);
}
dfs(1,-1);
cout<<res<<endl;
}
signed main() {
int size(512<<20); // 512M
__asm__ ( "movq %0, %%rsp\n"::"r"((char*)malloc(size)+size));
int t;
t=1;
cin>>t;
for(int i=1;i<=t;i++)
{
solve();
}
exit(0);
}