题目大意
给定一颗n节点树,问:在选择最少点情况下,使得树上所有点都与这些点中至少一个有边连接(除了选中的点)的方案数以及选择点数
这题,我们可以考虑用树形dp做,一个点的状态,0为没被选定且儿子节点都没被选定,1为没被选定且儿子节点至少一个被选定,2为被选定。
那么我们可以发现,当状态为2时,子节点任意状态都可以,保留选的点数最少的(同一儿子,多个方案就求和,不同儿子用乘法)
当状态为0时,儿子必然是1,不可能为0,否则儿子就不合法了。
当状态为1时,儿子至少有一个为2,这样我们可以枚举哪个儿子为2,为避免重复计算,我们按顺序做,保证在当前儿子为2时,前面做过的儿子都为1(不可能为0),后面的可1可2。(这个可以dp时先求出一个前缀积数组和后缀积数组表示方案数,还有前缀和和后缀和表示最小的选择点数,当然也可以用逆元来做,只需特判子节点里有0就可以了)
对于不存在的方案我们先对它选择点数赋值成无穷大(程序里我赋成了n
+1)
最后判定输出即可。
贴代码
#include<iostream>
#include<cstdio>
#include<algorithm>
#define N 100002
#define MOD 1000000007
using namespace std;
int n,ans1,ans,ff,ssum;
int g[N],a[N+N][2],f[N][3],sum[N][3],d[N],fdrev[N],sumdrev[N],fdord[N],sumdord[N];
struct node{
int v,num;
}b[3];
void ins(int x,int y){
static int sum=0;
a[++sum][0]=y,a[sum][1]=g[x],g[x]=sum;
}
void init(){
static int x,y;
scanf("%d",&n);
for (int i=1;i<n;i++)
scanf("%d %d",&x,&y),ins(x,y),ins(y,x);
}
bool cmp(const node&a,const node&b){
return a.num<b.num;
}
void did(int x){
if (sum[x][1]<sum[x][2])
ff=f[x][1],ssum=sum[x][1];
else
if (sum[x][1]==sum[x][2])
ff=(f[x][1]+f[x][2])%MOD,ssum=sum[x][1];
else
ff=f[x][2],ssum=sum[x][2];
}
void dfs(int x,int fa){
static int y,z,s,ss;
bool p,p1;
p=0,p1=1;
f[x][0]=1,f[x][2]=1;
sum[x][0]=0;
sum[x][2]=1;
sum[x][1]=N;
for (int i=g[x];i;i=a[i][1])
if (fa!=a[i][0]){
p=1;
dfs(a[i][0],x);
y=a[i][0];
b[0].v=f[y][0],b[0].num=sum[y][0];
b[1].v=f[y][1],b[1].num=sum[y][1];
b[2].v=f[y][2],b[2].num=sum[y][2];
sort(b,b+3,cmp);
z=b[0].v;
for (int i=1;i<=2;i++)
if (b[i].num==b[i-1].num)
(z+=b[i].v)%=MOD;
else
break;
sum[x][2]+=b[0].num;
f[x][2]=(long long)f[x][2]*z%MOD;
if (sum[y][1]==N)p1=0;
else
f[x][0]=(long long)f[x][0]*f[y][1]%MOD,sum[x][0]+=sum[y][1];
}
if (!p1)sum[x][0]=N;
if (p){
d[0]=0;
for (int i=g[x];i;i=a[i][1])
if (fa!=a[i][0]){
d[++d[0]]=y=a[i][0];
did(y);
fdrev[d[0]]=ff;
fdord[d[0]]=f[y][1];
sumdrev[d[0]]=ssum;
sumdord[d[0]]=sum[y][1];
}
fdrev[d[0]+1]=fdord[0]=1;
for (int i=2;i<=d[0];i++){
fdord[i]=(long long)fdord[i]*fdord[i-1]%MOD,sumdord[i]+=sumdord[i-1];
if (sumdord[i]>N)sumdord[i]=N;
}
sumdrev[d[0]+1]=0;
for (int i=d[0]-1;i;i--){
fdrev[i]=(long long)fdrev[i]*fdrev[i+1]%MOD,sumdrev[i]+=sumdrev[i+1];
if (sumdrev[i]>N)sumdrev[i]=N;
}
for (int i=1;i<=d[0];i++){
y=d[i];
s=sumdord[i-1]+sumdrev[i+1]+sum[y][2];
ss=(long long)fdord[i-1]*fdrev[i+1]%MOD*f[y][2]%MOD;
if (s<sum[x][1])
sum[x][1]=s,f[x][1]=ss;
else
if (s==sum[x][1])
f[x][1]=(f[x][1]+ss)%MOD;
}
}
}
void work(){
dfs(1,0);
if (sum[1][1]<sum[1][2])
ans=sum[1][1],ans1=f[1][1];
else
if (sum[1][1]==sum[1][2])
ans=sum[1][1],ans1=(f[1][1]+f[1][2])%MOD;
else
ans=sum[1][2],ans1=f[1][2];
}
void write(){
printf("%d\n",ans);
printf("%d\n",ans1);
}
int main(){
init();
work();
write();
return 0;
}