题意:
给定一棵无根树,求其中本质不同的独立集的个数。独立集就是一个集合中的点之间都没有边直接相连。
n
<
=
5
e
5
n<=5e5
n<=5e5,对
1
e
9
+
7
1e9+7
1e9+7取模。
题解:
首先膜拜一下
y
_
i
m
m
o
r
t
a
l
y\_immortal
y_immortal神仙,是这个神仙教的我这个题怎么做QwQ.
首先考虑没有本质不同应该怎么算。我们设 d p [ x ] [ 0 ] dp[x][0] dp[x][0]表示考虑 x x x为根的子树内不选 x x x这个点的方案数,设 d p [ x ] [ 1 ] dp[x][1] dp[x][1]表示考虑 x x x为根的子树内选 x x x这个点的方案数。我们枚举 x x x的每个子树,我们用子树的方案数乘起来就是答案。 d p [ x ] [ 0 ] = ∏ y ∈ s o n [ x ] ( d p [ y ] [ 0 ] + d p [ y ] [ 1 ] ) dp[x][0]=\prod_{y\in son[x]}(dp[y][0]+dp[y][1]) dp[x][0]=∏y∈son[x](dp[y][0]+dp[y][1]), d p [ x ] [ 1 ] = ∏ y ∈ s o n [ x ] d p [ y ] [ 0 ] dp[x][1]=\prod_{y\in son[x]}dp[y][0] dp[x][1]=∏y∈son[x]dp[y][0]
于是考虑有本质不同怎么来算。我们设 d p [ x ] [ 0 ] dp[x][0] dp[x][0]表示考虑 x x x为根的子树内不选 x x x这个点本质不同的独立集数,设 d p [ x ] [ 1 ] dp[x][1] dp[x][1]表示考虑 x x x为根的子树选 x x x这个点的本质不同的独立集数。我们把本质相同的树放在一起考虑,我们假设现在考虑到的这种本质相同的子树在 x x x的子树中有 k k k棵,这种子树的根设为 y y y节点。为了保证算方案的时候不会重复,我们给所有 y y y中可以的方案编一个号,也就是说 y y y子树中有多少种方案,最大的一个编号就是多少。之后我们为了不重复,所有相同的这些子树,规定前面的子树选的编号要小于等于后面的,这样就可以不重不漏。而这个东西应该是一个可重复的组合数。于是我们把每一类本质相同的 y y y放在一起算,有 d p [ x ] [ 0 ] = ∏ y ∈ s o n [ x ] ( d p [ y ] [ 1 ] + d p [ y ] [ 0 ] ) ∗ C d p [ y ] [ 1 ] + d p [ y ] [ 0 ] + k − 1 k dp[x][0]=\prod_{y\in son[x]}(dp[y][1]+dp[y][0])*C_{dp[y][1]+dp[y][0]+k-1}^{k} dp[x][0]=∏y∈son[x](dp[y][1]+dp[y][0])∗Cdp[y][1]+dp[y][0]+k−1k , d p [ x ] [ 1 ] = ∏ y ∈ s o n [ x ] d p [ y ] [ 0 ] ∗ C d p [ y ] [ 0 ] + k − 1 k dp[x][1]=\prod_{y\in son[x]}dp[y][0]*C_{dp[y][0]+k-1}^{k} dp[x][1]=∏y∈son[x]dp[y][0]∗Cdp[y][0]+k−1k
那么下面的问题是如何判断两棵树是否本质相同。我们要做的是树哈希。这里提供一种树哈希的方法。我们设叶子节点的哈希值是 1 1 1,然后对于其他点,我们把他们的子树按照哈希值排序,然后依次乘进去,乘的时候像字符串哈希那样,把每一个子树看作一个字符,乘一个底数的多少次幂再加进来。最后再乘一个这个子树的size。反正这样起码能保证相同的不会判成不同。
另外就是求组合数的时候,没法求 1 e 9 + 7 1e9+7 1e9+7那么大的,不过我们发现,大部分项可以被约掉,就剩下 m m m项,我们暴力算这 m m m项就好,我们每个点最多被在组合数里算一次,所以最终还是线性的。
复杂度 O ( n ) O(n) O(n)。
代码:
#include <bits/stdc++.h>
using namespace std;
int n,hed[500010],cnt,sz[500010],mx[500010],rt,rt1,rt2,fa[500010];
long long ans,jie[500010],ni[500010],mi[500010],ha[500010],dp[500010][2];
const long long mod=1e9+7,base=23333;
struct node
{
int to,next;
}a[2000010];
inline int read()
{
int x=0;
char s=getchar();
while(s>'9'||s<'0')
s=getchar();
while(s>='0'&&s<='9')
{
x=x*10+s-'0';
s=getchar();
}
return x;
}
inline long long ksm(long long x,long long y)
{
long long res=1;
while(y)
{
if(y&1)
res=res*x%mod;
x=x*x%mod;
y>>=1;
}
return res;
}
inline void add(int from,int to)
{
a[++cnt].to=to;
a[cnt].next=hed[from];
hed[from]=cnt;
}
inline void getrt(int x,int f)
{
sz[x]=1;
mx[x]=0;
for(int i=hed[x];i;i=a[i].next)
{
int y=a[i].to;
if(y==f)
continue;
getrt(y,x);
sz[x]+=sz[y];
mx[x]=max(sz[y],mx[x]);
}
mx[x]=max(mx[x],n-sz[x]);
if(mx[rt1]>mx[x])
rt1=x;
else if(mx[x]==mx[rt1])
rt2=x;
}
inline void dfs(int x)
{
sz[x]=1;
for(int i=hed[x];i;i=a[i].next)
{
int y=a[i].to;
if(y==fa[x])
continue;
fa[y]=x;
dfs(y);
sz[x]+=sz[y];
}
}
inline int cmp(int x,int y)
{
return ha[x]<ha[y];
}
inline void dfs1(int x)
{
int num=0;
vector<int> v;
v.clear();
ha[x]=1;
for(int i=hed[x];i;i=a[i].next)
{
int y=a[i].to;
if(y==fa[x])
continue;
dfs1(y);
}
for(int i=hed[x];i;i=a[i].next)
{
int y=a[i].to;
if(y==fa[x])
continue;
v.push_back(y);
++num;
}
sort(v.begin(),v.end(),cmp);
for(int i=0;i<num;++i)
ha[x]=(ha[x]+ha[v[i]]*mi[i+1])%mod;
ha[x]=ha[x]*sz[x]%mod;
}
inline long long C(int n,int m)
{
n%=mod;
long long res=1;
for(int i=n-m+1;i<=n;++i)
res=res*i%mod;
res=res*ni[m]%mod;
return res;
}
inline void dfs2(int x)
{
dp[x][0]=dp[x][1]=1;
int num=0;
vector<int> v;
v.clear();
for(int i=hed[x];i;i=a[i].next)
{
int y=a[i].to;
if(y==fa[x])
continue;
dfs2(y);
v.push_back(y);
++num;
}
sort(v.begin(),v.end(),cmp);
int shu=1;
v.push_back(n+2);
for(int i=1;i<=num;++i)
{
if(ha[v[i]]!=ha[v[i-1]])
{
long long qwq=(dp[v[i-1]][0]+dp[v[i-1]][1])%mod,qwqq;
qwqq=dp[v[i-1]][0];
dp[x][0]=dp[x][0]*C(qwq+shu-1,shu)%mod;
dp[x][1]=dp[x][1]*C(qwqq+shu-1,shu)%mod;
shu=1;
}
else
++shu;
}
}
int main()
{
n=read();
ha[n+2]=mod+2;
for(int i=1;i<=n-1;++i)
{
int x=read(),y=read();
add(x,y);
add(y,x);
}
mx[0]=2e9;
getrt(1,0);
jie[0]=1;
for(int i=1;i<=n;++i)
jie[i]=jie[i-1]*i%mod;
ni[n]=ksm(jie[n],mod-2);
for(int i=n-1;i>=0;--i)
ni[i]=ni[i+1]*(i+1)%mod;
mi[0]=1;
for(int i=1;i<=n;++i)
mi[i]=mi[i-1]*base%mod;
if(mx[rt2]==mx[rt1])
{
rt=n+1;
for(int i=hed[rt1];i;i=a[i].next)
{
int y=a[i].to;
if(y==rt2)
{
a[i].to=rt;
break;
}
}
for(int i=hed[rt2];i;i=a[i].next)
{
int y=a[i].to;
if(y==rt1)
{
a[i].to=rt;
break;
}
}
add(rt,rt1);
add(rt,rt2);
}
else
rt=rt1;
memset(sz,0,sizeof(sz));
dfs(rt);
dfs1(rt);
dfs2(rt);
if(rt==rt1)
{
ans=(dp[rt][0]+dp[rt][1])%mod;
printf("%lld\n",ans);
return 0;
}
if(ha[rt1]==ha[rt2])
{
ans=(dp[rt1][0]*dp[rt2][1]%mod+C(dp[rt1][0]+1,2))%mod;
printf("%lld\n",ans);
}
else
{
ans=(dp[rt1][0]*dp[rt2][1]%mod+dp[rt1][1]*dp[rt2][0]%mod+dp[rt1][0]*dp[rt2][0]%mod)%mod;
printf("%lld\n",ans);
}
return 0;
}