题意
JYY有两棵树A和B:树A有N个点,编号为1到N;树B有N+1个点,编号为1到N+1。JYY知道树B恰好是由树A加上一个叶节点,然后将节点的编号打乱后得到的。他想知道,这个多余的叶子到底是树B中的哪一个叶节点呢?
题解
好久没有写过树hash了。。并不知道怎么写简单
Rose告诉了我一个不错的hash方法
我们只需要
f
x
=
b
a
s
e
×
(
Π
f
s
o
n
+
t
o
t
x
)
f_x=base\times (\Pi f_{son}+tot_x)
fx=base×(Πfson+totx)
f
f
f是每一个子树的hash值
可以发现,这玩意支持换根,那么就不限于找重心了
那第一颗树所有的hash值丢进去
第二颗树删除一个节点的hash值也可以用类似的方法弄出来
弄个set,找一找有没有出现过就好了
一开始base设为2333,模数为1e9+7被卡了。。
然后把base改为了233才过
em…脸有点黑
CODE:
#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<set>
using namespace std;
typedef long long LL;
const int MOD=1e9+7;
const int base=233;
const int N=100005;
int Inv;
int add (int x,int y) {x=x+y;return x>=MOD?x-MOD:x;}
int mul (int x,int y) {return (LL)x*y%MOD;}
int dec (int x,int y) {x=x-y;return x<0?x+MOD:x;}
int Pow (int x,int y)
{
if (y==1) return x;
int lalal=Pow(x,y>>1);
lalal=mul(lalal,lalal);
if (y&1) lalal=mul(lalal,x);
return lalal;
}
int n;
struct qq
{
int x,y,last;
}e[N*2];int num,last[N];
void init (int x,int y)
{
e[++num].x=x;e[num].y=y;
e[num].last=last[x];
last[x]=num;
}
int f[N];
int tot[N];
void dfs (int x,int fa)
{
tot[x]=1;f[x]=1;
for (int u=last[x];u!=-1;u=e[u].last)
{
int y=e[u].y;
if (y==fa) continue;
dfs(y,x);
tot[x]=tot[x]+tot[y];
f[x]=mul(f[x],f[y]);
}
f[x]=add(f[x],tot[x]);
f[x]=mul(f[x],base);
}
set<int> s;
void dfs1 (int x,int fa)
{
s.insert(f[x]);
int lalal=dec(mul(f[x],Inv),n),xx;
for (int u=last[x];u!=-1;u=e[u].last)
{
int y=e[u].y;
if (y==fa) continue;
xx=mul(lalal,Pow(f[y],MOD-2));
xx=add(xx,n-tot[y]);
xx=mul(xx,base);
f[y]=mul(f[y],Inv);
f[y]=dec(f[y],tot[y]);
f[y]=mul(f[y],xx);
f[y]=add(f[y],n);
f[y]=mul(f[y],base);
dfs1(y,x);
}
}
int du[N];
void dfs2 (int x,int fa)
{
int lalal=dec(mul(f[x],Inv),n),xx;
for (int u=last[x];u!=-1;u=e[u].last)
{
int y=e[u].y;
if (y==fa) continue;
xx=mul(lalal,Pow(f[y],MOD-2));
xx=add(xx,n-tot[y]);
xx=mul(xx,base);
f[y]=mul(f[y],Inv);
f[y]=dec(f[y],tot[y]);
f[y]=mul(f[y],xx);
f[y]=add(f[y],n);
f[y]=mul(f[y],base);
dfs2(y,x);
}
}
int main()
{
Inv=Pow(base,MOD-2);
num=0;memset(last,-1,sizeof(last));
scanf("%d",&n);
for (int u=1;u<n;u++)
{
int x,y;
scanf("%d%d",&x,&y);
init(x,y);init(y,x);
}
dfs(1,0);
dfs1(1,0);
/*for (int u=1;u<=n;u++) printf("%d ",f[u]);
printf("\n");*/
num=0;memset(last,-1,sizeof(last));n++;
for (int u=1;u<n;u++)
{
int x,y;
scanf("%d%d",&x,&y);
du[x]++;du[y]++;
init(x,y);init(y,x);
}
dfs(1,0);dfs2(1,0);
for (int u=1;u<=n;u++)
{
if (du[u]==1)
{
int xx=mul(f[u],Inv);
xx=dec(xx,n);
if (s.find(xx)!=s.end())
{
printf("%d\n",u);
break;
}
}
}
return 0;
}