题目大意:给出一棵树,树边上有一个字符,问每个点的子树中最长的合法路径长度。
所谓的合法路径长度就是给路径中的所有字符重新组合后可以是回文串。
题解:dsu on the tree如果是回文串,最多有1个字符是单数个。我们将每个字符压成二进制中的一位,那么对于每个位置维护根到该点路径上的异或值,如果两个点的异或值是0或者2^x那么就是一个合法的方案。所以我们可以对于每个异或值维护深度最深的点的深度。
每次统计一个节点的答案,依次加入他的每个儿子,注意一定是先计算答案,再加入。保证选中的点不是来自一棵子树。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#define N 2000003
#define M 22
#define inf 1000000000
using namespace std;
int mp[N*6],tot,nxt[N],point[N],v[N],c[N],val[N],size[N],son[N];
int deep[N],ans[N],mx,mark[N],n;
void add(int x,int y,int z){
tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; c[tot]=z;
tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x; c[tot]=z;
}
void solve(int x,int fa)
{
size[x]=1; deep[x]=deep[fa]+1;
for (int i=point[x];i;i=nxt[i]) {
if (v[i]==fa) continue;
val[v[i]]=val[x]^c[i];
solve(v[i],x);
size[x]+=size[v[i]];
if (size[son[x]]<size[v[i]]) son[x]=v[i];
}
}
void get_ans(int dep,int x,int fa)
{
mx=max(mx,mp[val[x]]+deep[x]-2*dep);
for (int i=0;i<=21;i++) mx=max(mx,mp[val[x]^(1<<i)]+deep[x]-2*dep);
for (int i=point[x];i;i=nxt[i])
if (v[i]!=fa&&!mark[v[i]]) get_ans(dep,v[i],x);
}
void change(int x,int fa)
{
mp[val[x]]=max(mp[val[x]],deep[x]);
for (int i=point[x];i;i=nxt[i])
if (v[i]!=fa&&!mark[v[i]]) change(v[i],x);
}
void init(int x,int fa)
{
mp[val[x]]=-inf;
for (int i=point[x];i;i=nxt[i])
if (v[i]!=fa&&!mark[v[i]]) init(v[i],x);
}
void dfs(int x,int fa,bool k)
{
for (int i=point[x];i;i=nxt[i])
if (v[i]!=fa&&v[i]!=son[x]) dfs(v[i],x,0),ans[x]=max(ans[x],ans[v[i]]);
if (son[x]) dfs(son[x],x,1),ans[x]=max(ans[x],ans[son[x]]),mark[son[x]]=1;
for (int i=point[x];i;i=nxt[i])
if (v[i]!=son[x]&&v[i]!=fa) get_ans(deep[x],v[i],x),change(v[i],x);
mx=max(mx,mp[val[x]]-deep[x]);
for (int i=0;i<=21;i++) mx=max(mx,mp[val[x]^(1<<i)]-deep[x]);
mp[val[x]]=max(mp[val[x]],deep[x]);
ans[x]=max(ans[x],mx);
if(son[x]) mark[son[x]]=0;
if (!k) init(x,fa),mx=-inf;
}
int main()
{
freopen("a.in","r",stdin);
scanf("%d",&n);
for (int i=2;i<=n;i++) {
int x; char c[3];
scanf("%d%s",&x,c);
add(x,i,1<<(c[0]-'a'));
}
for (int i=0;i<=(1<<M);i++) mp[i]=-inf;
solve(1,0);
dfs(1,0,0);
for (int i=1;i<=n;i++) printf("%d ",ans[i]);
}