因为字符随机,所以同样的字符很少,我们对于同样的字符内部暴力两两枚举,看x->y是不是一个回文串。
怎么看呢?蒟蒻我不会…
比赛时胡乱写了一个假的lct维护树上hash值。
一开始没维护sz,它过了…过了…
发现后改过,它RE了…缘来是我写挂了qaq
改对过了美滋滋,然而elijahqi巨佬告诉我你太naive了…你这样rev之后hash值根本不对…囧
然而过了,可能因为数据随机吧qaq,反正正解也是爆搜【逃】
7.3upd::直接树上hash就好了,两个方向的Hash值,需要除法,所以必须得模一个质数qaq。
树上Hash
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define inf 0x3f3f3f3f
#define N 100010
#define k1 11119
#define mod 1004535809
inline char gc(){
static char buf[1<<16],*S,*T;
if(S==T){T=(S=buf)+fread(buf,1,1<<16,stdin);if(T==S) return EOF;}
return *S++;
}
inline int read(){
int x=0,f=1;char ch=gc();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=gc();}
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=gc();
return x*f;
}
int n,fa[N][20],h[N],num=0,dep[N],Log[N],hs1[N],hs2[N],bin[N],inv[N],v[N];
struct edge{
int to,next;
}data[N<<1];
vector<int>a[N];
inline int ksm(int x,int k){
int res=1;for(;k;k>>=1,x=(ll)x*x%mod) if(k&1) res=(ll)res*x%mod;return res;
}
inline void dfs(int x){
hs1[x]=(hs1[fa[x][0]]+(ll)v[x]*bin[dep[x]])%mod;
hs2[x]=((ll)hs2[fa[x][0]]*k1+v[x])%mod;
for(int i=1;i<=Log[n];++i){
if(!fa[x][i-1]) break;
fa[x][i]=fa[fa[x][i-1]][i-1];
}for(int i=h[x];i;i=data[i].next){
int y=data[i].to;if(y==fa[x][0]) continue;
fa[y][0]=x;dep[y]=dep[x]+1;dfs(y);
}
}
inline int lca(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
int d=dep[x]-dep[y];
for(int i=0;i<=Log[d];++i)
if(d>>i&1) x=fa[x][i];
if(x==y) return x;
for(int i=Log[n];i>=0;--i)
if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
inline int calc(int x,int y){
int t=lca(x,y);
int res=(hs2[x]-(ll)hs2[fa[t][0]]*bin[dep[x]-dep[t]+1])%mod;
res+=(ll)(hs1[y]-hs1[t])*inv[dep[t]]%mod*bin[dep[x]-dep[t]]%mod;
res=(res%mod+mod)%mod;return res;
}
int main(){
// freopen("a.in","r",stdin);
int tst=read();n=100000;int ik1=ksm(k1,mod-2);Log[0]=-1;
for(int i=1;i<=n;++i) Log[i]=Log[i>>1]+1;bin[0]=inv[0]=1;
for(int i=1;i<=n;++i) bin[i]=(ll)bin[i-1]*k1%mod,inv[i]=(ll)inv[i-1]*ik1%mod;
while(tst--){
n=read();ll ans=n;memset(fa,0,sizeof(fa));memset(h,0,sizeof(h));num=0;
for(int i=1;i<=n;++i) v[i]=read(),a[v[i]].push_back(i);
for(int i=1;i<n;++i){
int x=read(),y=read();
data[++num].to=y;data[num].next=h[x];h[x]=num;
data[++num].to=x;data[num].next=h[y];h[y]=num;
}dep[1]=1;dfs(1);
for(int i=1;i<=n;++i)
for(int j=0;j<a[i].size();++j)
for(int k=j+1;k<a[i].size();++k)
if(calc(a[i][j],a[i][k])==calc(a[i][k],a[i][j])) ++ans;
printf("%lld\n",ans);
for(int i=1;i<=n;++i) a[i].clear();
}return 0;
}
骗分lct
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define inf 0x3f3f3f3f
#define N 100010
#define k1 11117
#define ull unsigned long long
inline char gc(){
static char buf[1<<16],*S,*T;
if(S==T){T=(S=buf)+fread(buf,1,1<<16,stdin);if(T==S) return EOF;}
return *S++;
}
inline int read(){
int x=0,f=1;char ch=gc();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=gc();}
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=gc();
return x*f;
}
int n,fa[N],c[N][2],sz[N],v[N],q[N];
ull hs[N],bin[N];ll ans=0;
bool rev[N];
vector<int>a[N];
inline void update(int p){
int l=c[p][0],r=c[p][1];
sz[p]=sz[l]+sz[r]+1;
hs[p]=hs[l]*bin[sz[r]+1]+v[p]*bin[sz[r]]+hs[r];
}
inline void dorev(int x){
if(!x) return;rev[x]^=1;swap(c[x][0],c[x][1]);update(x);
}
inline bool isroot(int x){
return x!=c[fa[x]][0]&&x!=c[fa[x]][1];
}
inline void pushdown(int p){
if(!rev[p]) return;rev[p]=0;
dorev(c[p][0]);dorev(c[p][1]);
}
inline void rotate(int x){
int y=fa[x],z=fa[y],l=x==c[y][1],r=l^1;
if(!isroot(y)) c[z][y==c[z][1]]=x;
fa[c[x][r]]=y;fa[y]=x;fa[x]=z;
c[y][l]=c[x][r];c[x][r]=y;update(y);update(x);
}
inline void splay(int x){
int top=0;q[++top]=x;
for(int xx=x;!isroot(xx);xx=fa[xx]) q[++top]=fa[xx];
while(top) pushdown(q[top--]);
while(!isroot(x)){
int y=fa[x],z=fa[y];
if(!isroot(y)){
if(x==c[y][1]^y==c[z][1]) rotate(x);
else rotate(y);
}rotate(x);
}
}
inline void access(int x){
int y=0;while(x){splay(x);c[x][1]=y;update(x);y=x;x=fa[x];}
}
inline void makeroot(int x){
access(x);splay(x);dorev(x);
}
inline void link(int x,int y){
makeroot(x);fa[x]=y;
}
inline ull cal(int x,int y){
makeroot(x);access(y);splay(y);return hs[y];
}
int main(){
// freopen("a.in","r",stdin);
int tst=read();
while(tst--){
n=read();ans=n;memset(fa,0,sizeof(fa));memset(c,0,sizeof(c));bin[0]=1;
memset(rev,0,sizeof(rev));memset(hs,0,sizeof(hs));memset(sz,0,sizeof(sz));
for(int i=1;i<=n;++i) v[i]=read(),a[v[i]].push_back(i),bin[i]=bin[i-1]*k1;
for(int i=1;i<n;++i){
int x=read(),y=read();link(x,y);
}for(int i=1;i<=n;++i)
for(int j=0;j<a[i].size();++j)
for(int k=j+1;k<a[i].size();++k)
if(cal(a[i][j],a[i][k])==cal(a[i][k],a[i][j])) ++ans;
printf("%lld\n",ans);
for(int i=1;i<=n;++i) a[i].clear();
}return 0;
}