【题意】给定一棵树,每个点都有一个a~t的字符,一条路径回文定义为路径上的字符存在一个排列构成回文串,求经过每个点的回文路径数。n<=2*10^5。
【算法】点分治
【题解】状压20位的二进制表示一条路径的字符状态,点分治过程中维护扫描过的路径只须维护状态桶数组,t[i]表示前面状态为i的路径条数。
合并:考虑当前状态为j,要使合并的状态满足条件即i^j=1<<k(0<=k<20)或i^j=0,移项得i=j^(1<<k)或i=j,所以路径数是Σ t [ j^(1<<k) ]+t[j]。
统计每个点:对于当前要处理的重心x,先遍历所有子树得到整个t[]数组,然后对每个子树先删除其在桶里的状态,然后扫一遍贡献子树中每个点,最后将子树的状态加回桶中。
这样可以做到每条路径都贡献到每个点,要特殊处理重心的贡献。
复杂度O(n log n)。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <string>
#include <algorithm>
#include <queue>
#include <vector>
using namespace std;
#define ll long long
const int maxn=200005;
const int maxN=2000005;
int tot,first[maxn],sz[maxn],vis[maxn],sum,root,a[maxn],u,v,n;
ll ans[maxn],t[maxN];
struct node
{
int v,next;
}e[maxn*2];
void insert(int u,int v)
{
tot++;
e[tot].v=v;e[tot].next=first[u];first[u]=tot;
}
void getroot(int x,int fa)
{
sz[x]=1;
bool ok=1;
for(int i=first[x];i;i=e[i].next)
{
int v=e[i].v;
if(fa!=v && vis[v]==0)
{
getroot(v,x);
sz[x]+=sz[v];
if(sz[v]>sum/2)ok=0;
}
}
if(ok==1 && sz[x]>=sum/2)
root=x;
}
void dfs(int x,int fa,int p,int s)
{
t[s^=(1<<a[x])]+=p;
for(int i=first[x];i;i=e[i].next)
{
int v=e[i].v;
if(v!=fa && vis[v]==0)
{
dfs(v,x,p,s);
}
}
}
ll calc(int x,int fa,int s)
{
s^=(1<<a[x]);
ll num=t[s];
for(int i=0;i<20;i++)
{
num+=t[s^(1<<i)];
}
for(int i=first[x];i;i=e[i].next)
{
int v=e[i].v;
if(v!=fa && vis[v]==0)
{
num+=calc(v,x,s);
}
}
ans[x]+=num;
return num;
}
void solve(int x,int s)
{
vis[x]=1;
dfs(x,0,1,0);
ll num=t[0];
for(int i=0;i<20;i++)
{
num+=t[1<<i];
}
for(int i=first[x];i;i=e[i].next)
{
int v=e[i].v;
if(vis[v]==0)
{
dfs(v,x,-1,1<<a[x]);
num+=calc(v,x,0);
dfs(v,x,1,1<<a[x]);
}
}
ans[x]+=num/2;
dfs(x,0,-1,0);
for(int i=first[x];i;i=e[i].next)
{
int v=e[i].v;
if(vis[v]==0)
{
if(sz[v]>sz[x])
sum=s-sz[x];
else
sum=sz[v];
getroot(v,x);
solve(root,sum);
}
}
}
char s[maxn];
int main() {
scanf("%d",&n);
for(int i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
insert(u,v);insert(v,u);
}
scanf("%s",s+1);
for(int i=1;i<=n;i++)
{
a[i]=s[i]-'a';
}
sum=n;
getroot(1,0);
solve(root,sum);
for(int i=1;i<=n;i++)
{
printf("%lld ",ans[i]+1);
}
printf("\n");
return 0;
}