Description
Input
Output
求一棵树编号序列不同的方案数:
令 $f[u],g[u]$ 分别表示 $u$ 选/不选 的方案数.
则 $f[u]=\prod_{v\in son[u]}g[v]$,$g[u]=\prod_{v\in son[u]}g[v]+f[v]$.
然而如果要求本质不同,那么那些子树结构相同的就会算重.
假设有 $k$ 个儿子树形态相同,每一个儿子可选的方案为 $h$.
则我们要求给每一个儿子都分一种方案的方案数.
即有 $m$ 个相同的盒子,有 $k$ 种球,求给每一个盒子分配一个球(可重复)的方案数.
这个直接用可重集公式即可,即 $C_{k+m-1}^{m}$.
如何求得所有形态相同得子树呢?
这棵树无论如何旋转,重心都是不变的,以重心(或两重心之间连一个点)为根,进行树哈希+树形DP即可.
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;
typedef long long ll;
const int N=500003,mod=1000000007,mul=20011118,ha=20011118,con=2019;
vector<int>rt;
ll F[N],G[N];
int n,edges,M,root;
int hd[N],to[N<<1],nex[N<<1],mx[N],siz[N],Hash[N],sta[N];
ll qpow(ll base,ll k)
{
ll tmp=1ll;
for(;k;base=(base*base)%mod,k>>=1) if(k&1) tmp=(tmp*base)%mod;
return tmp;
}
ll inv(int a) { return qpow((ll)a, (ll)mod-2); }
bool cmp(int a,int b)
{
return Hash[a]<Hash[b];
}
inline void addedge(int u,int v)
{
nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;
}
void getroot(int u,int ff)
{
siz[u]=1,mx[u]=0;
for(int i=hd[u];i;i=nex[i])
if(to[i]!=ff)
getroot(to[i],u),siz[u]+=siz[to[i]],mx[u]=max(mx[u],siz[to[i]]);
M=min(M,mx[u]=max(mx[u],n-siz[u]));
}
ll C(int a,int b)
{
ll tmp=1;
for(int i=a-b+1;i<=a;++i) tmp=(1ll*i*tmp)%mod;
for(int i=1;i<=b;++i) tmp=(1ll*inv(i)*tmp)%mod;
return tmp;
}
void calc(int u,int ff)
{
int i,j,tmp=0;
Hash[u]=2019;
for(i=hd[u];i;i=nex[i])
if(to[i]!=ff)
calc(to[i],u);
sta[0]=0;
for(i=hd[u];i;i=nex[i])
if(to[i]!=ff)
sta[++sta[0]]=to[i];
sort(sta+1,sta+1+sta[0],cmp);
for(i=1;i<=sta[0];++i) Hash[u]=((ll)(Hash[u]*mul)^Hash[sta[i]])%ha;
F[u]=G[u]=1ll;
for(i=1;i<=sta[0];i=j+1)
{
j=i;
while(j<sta[0]&&Hash[sta[j+1]]==Hash[sta[j]]) ++j;
F[u]=(F[u]*C(G[sta[i]]+j-i, j-i+1))%mod;
G[u]=(G[u]*C(G[sta[i]]+F[sta[i]]+j-i, j-i+1))%mod;
}
}
int main()
{
int i,j;
// setIO("input");
scanf("%d",&n);
for(i=1;i<n;++i)
{
int x,y;
scanf("%d%d",&x,&y),addedge(x,y),addedge(y,x);
}
M=n,getroot(1,0);
for(i=1;i<=n;++i) if(mx[i]==M) rt.push_back(i);
if(rt.size()==2)
{
int pre;
root=++n;
addedge(n,rt[0]),addedge(n,rt[1]);
if(to[hd[rt[0]]]==rt[1]) hd[rt[0]]=nex[hd[rt[0]]];
else
{
for(pre=i=hd[rt[0]];i;pre=i,i=nex[i])
if(to[i]==rt[1]) { nex[pre]=nex[i]; break; }
}
if(to[hd[rt[1]]]==rt[0]) hd[rt[1]]=nex[hd[rt[1]]];
else
{
for(pre=i=hd[rt[1]];i;pre=i,i=nex[i])
if(to[i]==rt[0]) { nex[pre]=nex[i]; break; }
}
}else root=rt[0];
calc(root,0);
if(rt.size()==1) printf("%lld\n",(F[root]+G[root])%mod);
else
{
int a=rt[0],b=rt[1];
if(Hash[a]==Hash[b]) printf("%lld\n",(G[root]-C(F[a]+1,2)+mod)%mod);
else printf("%lld\n", (((F[a]*F[b])%mod) + ((F[a]*G[b])%mod) + ((G[a]*G[b])%mod)%mod));
}
return 0;
}