给出两棵二叉树,求这两棵树上有多少相同的子树。
相同的子树指树A中的子树a和树B中的子树b完全相同,二叉树的相同定义为树上总节点个数相同,根节点孩子数相同,而且两棵子树分别相同。当然,孩子节点是有先后顺序的!
输入格式:
第一行输入两个正整数N,M表示两棵二叉树的节点个数。
第2到N+1行每行两个整数X、Y,第i+1行表示树A上节点i的左孩子和右孩子分别是谁。若没有某个孩子则用-1表示。
第N+2到N+M+1行每行两个正整数X、Y,第i+N+1行表示树B上节点i的左孩子和右孩子分别是谁。若没有某个孩子同样用-1表示。
输出格式:
一行输出一个整数ANS表示相同子树的个数。
样例输入:
样例输出:
数据范围:
对于20%的数据1 ≤ N,M ≤ 100;
对于40%的数据1≤N,M ≤5,000;
对于100%的数据1≤ N,M ≤ 100,000。
时间限制:
1s
空间限制:
256M
水水的沙茶题把本沙茶X了。 dfn数组没开够啊,我记得我专门想了这个的啊!!!!
以前也是被这个DFN数组弄死了。太傻了。以后直接MAX_N*4好了也就不必纠结这种东西了。太沙茶了。。。
这题比较自然的想法是括号序列哈希,哈希出一颗树的所有子树的值用个map记录出现次数,然后再用另外一棵树上的子树哈希值去map里看有多少个。。由于儿子节点不能交换,所以每个儿子进去出来之后打一个独特的标记区分左右儿子。
由于怕冲突,我写的相当。。。 模两个再map<pair<LL,LL>,int>
当然还有可以无冲突的方法,用1...n一个数对应一个树的形态,然后两个子树的形态决定大子树的形态。用map映射,若有map[pair<i,j>]就直接取,否则map[pair<i,j>]=++total,这样就保证无冲了。。
虽然当时我想过第二种方法的。。但是莫名其妙的就没想了就直接想哈希了。。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <map>
#define rep(i,l,r) for (int i=l;i<=r;++i)
typedef long long LL;
typedef std::pair<LL,LL> PLL;
int getx(){
char c;int x;bool pd=0;
for (c=getchar();c!='-'&&(c<'0'||c>'9');c=getchar());
if (c=='-') pd=1,c=getchar();
for (x=0;c>='0'&&c<='9';c=getchar())
x=(x<<3)+(x<<1)+c-'0';
return pd?-x:x;
}
const int MAX_N=100005;
const LL base=127;
const LL Lmod1=389,Lmod2=223,Rmod1=431,Rmod2=587;
const LL MOD1=1000000007,MOD2=1743157613;
struct bintree{
int ch[MAX_N][2],n;
int lv[MAX_N],rv[MAX_N],Time;
int dfn[MAX_N*4];
void dfs(int v){
dfn[++Time]=Lmod1;
lv[v]=Time;
rep(k,0,1){
if (ch[v][k]!=-1) dfs(ch[v][k]);
dfn[++Time]=!k?Lmod2:Rmod2;
}
dfn[++Time]=Rmod1;
rv[v]=Time;
}
LL hash1[MAX_N*4],hash2[MAX_N*4];
LL pow1[MAX_N*4],pow2[MAX_N*4];
void calchash(){
pow1[0]=pow2[0]=1;
rep(i,1,Time){
pow1[i]=pow1[i-1]*base%MOD1;
pow2[i]=pow2[i-1]*base%MOD2;
hash1[i]=(hash1[i-1]*base+dfn[i])%MOD1;
hash2[i]=(hash2[i-1]*base+dfn[i])%MOD2;
}
}
PLL f(int L,int R){
LL h1=(hash1[R]-hash1[L-1]*pow1[R-L+1])%MOD1,
h2=(hash2[R]-hash2[L-1]*pow2[R-L+1])%MOD2;
if (h1<0) h1+=MOD1;if (h2<0) h2+=MOD2;
return PLL(h1,h2);
}
void init(int _n){
n=_n;
rep(i,1,n) ch[i][0]=getx(),ch[i][1]=getx();
dfs(1);
calchash();
}
}a,b;
std::map<PLL,int> map;
int n,m;
LL ans=0;
int main(){
n=getx(),m=getx();
a.init(n);b.init(m);
rep(i,1,n) map[a.f(a.lv[i],a.rv[i])]++;
rep(i,1,m) ans+=map[b.f(b.lv[i],b.rv[i])];
printf("%lld\n",ans);
}