采用的树哈希函数是:
d p x = w x × ∑ y ∈ x d p y 2 + w x 2 \Large dp_x=w_x\times \sum_{y\in x}dp_y^2+w_x^2 dpx=wx×y∈x∑dpy2+wx2
发现从 x x x 到 y y y 时只有 x x x 与 y y y 的哈希值会变化,分别维护即可
#include<bits/stdc++.h>
using namespace std;
#define int long long
inline int read(){int x=0,f=1;char ch=getchar(); while(ch<'0'||
ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){
x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}return x*f;}
#define Z(x) (x)*(x)
#define pb push_back
//mt19937 rand(time(0));
//mt19937_64 rand(time(0));
//srand(time(0));
#define N 100010
//#define M
#define mo (int)(1e9+123)
int n, m, i, j, k, T;
int pos, mx, cnt, h[N], w[N], dp[N], f[N], u, v;
map<int, int>mp;
vector<int>G[N];
int Mod(int a) {
return (a%mo+mo)%mo;
}
void add(int x, int k) {
mp[x]+=k;
if(mp[x]==1 && k==1) ++cnt;
if(mp[x]==0 && k==-1) --cnt;
// printf("# %lld (%lld): %lld\n", x, mp[x], cnt);
}
void dfs1(int x, int fa) {
// int s1, s2=0;
w[x]=1;
for(int y : G[x]) {
if(y==fa) continue;
dfs1(y, x);
w[x]+=w[y]; f[x]=Mod(f[x]+dp[y]*dp[y]);
}
dp[x]=Mod(w[x]*f[x]%mo+w[x]*w[x]%mo);
add(dp[x], 1);
// printf("%lld : %lld\n", x, dp[x]);
}
void dfs2(int x, int fa) {
int xw, xf, xdp, yw, yf, ydp;
for(int y : G[x]) {
if(y==fa) continue;
// printf("del [%lld] : %lld\n", x, dp[x]);
add(dp[x], -1); xdp=dp[x]; xw=w[x]; xf=f[x];
w[x]=w[x]-w[y]; f[x]=Mod(f[x]-dp[y]*dp[y]);
dp[x]=(w[x]*f[x]%mo+w[x]*w[x]%mo);
// printf("ins [%lld] %lld : %lld\n", x, w[x], dp[x]);
add(dp[x], 1);
// printf("del [%lld] : %lld\n", y, dp[y]);
add(dp[y], -1); ydp=dp[y]; yw=w[y]; yf=f[y];
w[y]=n; f[y]=Mod(f[y]+dp[x]*dp[x])%mo;
dp[y]=(w[y]*f[y]%mo+w[y]*w[y]%mo);
// printf("ins [%lld] : %lld\n", y, dp[y]);
add(dp[y], 1);
// printf("%lld : %lld\n", y, cnt);
// for(auto t=mp.begin(); t!=mp.end(); ++t) printf("%lld ", t);
if(cnt>mx) mx=cnt, pos=y;
dfs2(y, x);
add(dp[x], -1); add(dp[y], -1);
dp[x]=xdp; w[x]=xw; f[x]=xf; add(dp[x], 1);
dp[y]=ydp; w[y]=yw; f[y]=yf; add(dp[y], 1);
}
}
signed main()
{
// freopen("in.txt", "r", stdin);
// freopen("out.txt", "w", stdout);
// T=read();
// while(T--) {
//
// }
n=read();
for(i=1, k=1; i<n; ++i) {
u=read(); v=read();
G[u].pb(v); G[v].pb(u);
}
dfs1(1, 0);
mx=cnt; pos=1;
dfs2(1, 0);
printf("%lld", pos);
return 0;
}