一棵树,点可以染色,但是染色的点不能相邻,如果不考虑同构那这就是很简单的树形DP,设
f[i][1]
为
i
节点染色子树的总方案数,设
先考虑判断同构,将树hash一下,hash的方式怎么弄其实都行,比如将所有孩子的hash值加起来,子树大小为size,就再乘一个
然后考虑树形DP,将孩子里同构的孩子合并起来后对于剩余的就可以直接按照无同构的情况转移,同构的合并,假设有k个孩子同构,他们的
树根要选择树的中心,因为树不管怎么转他的重心是不变的,如果这棵树有两个重心还要记得新建一个节点连向这两个重心作为新的树根,最后计算答案的时候处理一下
code:
#include<set>
#include<map>
#include<deque>
#include<queue>
#include<stack>
#include<cmath>
#include<ctime>
#include<bitset>
#include<string>
#include<vector>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<climits>
#include<complex>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
const int maxn = 510000;
const ll Mod = 1e9+7;
struct edge
{
int y,nex;
edge(){}
edge(int _y,int _nex){y=_y;nex=_nex;}
}a[maxn<<1]; int len,fir[maxn];
void ins(int x,int y){a[++len]=edge(y,fir[x]);fir[x]=len;}
int n,siz[maxn];
ll ny[maxn];
void get_ny()
{
ny[1]=1;
for(ll i=2;i<maxn;i++)
ny[i]=((-(Mod/i)*ny[Mod%i])%Mod+Mod)%Mod;
}
void find_root(int x,int f,int &rt,int sum)
{
siz[x]=1;
bool flag=true;
for(int k=fir[x];k;k=a[k].nex)
{
int y=a[k].y;
if(y!=f)
{
find_root(y,x,rt,sum);
if(siz[y]*2>sum) flag=false;
siz[x]+=siz[y];
}
}
if((sum-siz[x])*2>sum) flag=false;
if(flag) rt=x;
}
bool ifroot(int x,int sum)
{
bool flag=true;
for(int k=fir[x];k;k=a[k].nex)
{
int y=a[k].y;
if(siz[y]>siz[x])
{
if((sum-siz[x])*2>sum) flag=false;
}
else if(siz[y]*2>sum) flag=false;
}
return flag;
}
int get_root()
{
int x;
find_root(1,0,x,n);
for(int k=fir[x];k;k=a[k].nex)
{
int y=a[k].y;
if(ifroot(y,n))
{
n++;
ins(n,x);
ins(n,y);
for(int k=fir[x];k;k=a[k].nex)
{
if(a[k].y==y) { a[k].y=n; break;}
}
for(int k=fir[y];k;k=a[k].nex)
{
if(a[k].y==x) { a[k].y=n; break; }
}
return n;
}
}
return x;
}
struct node
{
int h,x;
node(){}
node(int _h,int _x){h=_h;x=_x;}
}t[maxn]; int tp[maxn];
bool cmp(node x,node y) {return x.h<y.h;}
ll f[maxn][2];
int hash[maxn],pw[maxn];
vector<node>v[maxn];
ll C(ll x,ll y)
{
x%=Mod; ll r=1,cc=0;
for(ll i=x;i>=x-y+1;i--) cc++,r=r*i%Mod*ny[cc]%Mod;
return r;
}
void solve(int x,int fa)
{
siz[x]=1;
hash[x]=pw[1];
f[x][1]=f[x][0]=1;
for(int k=fir[x];k;k=a[k].nex)
{
int y=a[k].y;
if(y==fa)continue;
solve(y,x);
v[x].push_back(node(hash[y],y));
siz[x]+=siz[y];
hash[x]+=hash[y]*pw[siz[y]];
}
if(siz[x]!=1)
{
tp[x]=0;
for(int i=0;i<v[x].size();i++) t[++tp[x]]=v[x][i];
sort(t+1,t+tp[x]+1,cmp);
for(int i=1;i<=tp[x];i++)
{
int j;
for(j=i;j<tp[x]&&t[j].h==t[j+1].h&&f[t[j].x][0]==f[t[j+1].x][0];j++);
ll t0=f[t[i].x][0],t1=t0+f[t[i].x][1];
t0=C(t0+j-i,j-i+1);
t1=C(t1+j-i,j-i+1);
(f[x][0]*=t1)%=Mod;
(f[x][1]*=t0)%=Mod;
i=j;
}
}
}
int main()
{
get_ny();
scanf("%d",&n); int N=n;
for(int i=1;i<n;i++)
{
int x,y; scanf("%d%d",&x,&y);
ins(x,y); ins(y,x);
}
int rt=get_root();
pw[0]=1; for(int i=1;i<=n;i++) pw[i]=pw[i-1]*233;
solve(rt,0);
if(rt==N+1)
{
int tx=0,ty=0;
for(int k=fir[rt];k;k=a[k].nex)
{
if(!ty)ty=a[k].y;
else tx=a[k].y;
}
if(hash[tx]==hash[ty]&&f[tx][0]==f[ty][0])
printf("%lld\n",(f[tx][0]*f[ty][1]%Mod+C(f[tx][0]+1ll,2ll))%Mod);
else printf("%lld\n",
(f[tx][0]*f[ty][1]%Mod+f[tx][1]*f[ty][0]%Mod+f[tx][0]*f[ty][0]%Mod)%Mod);
}
else printf("%lld\n",(f[rt][0]+f[rt][1])%Mod);
return 0;
}