Codeforces-914E-Palindromes in a Tree
题意是给定一棵树,每个节点有一个字母权值,对于每个点
u
u
u ,求出 ans[u]
,ans[u]
的含义为有多少条穿过
u
u
u 的路径,使得路径上的权值能凑出一个回文串。
很显然,每个权值可以表示为一个二进制位。当一条路径的权值的异或和为
0
0
0 (所有字母出现次数都为偶数) 或只有一位置
1
1
1 (只有一个字母的出现次数为奇数)时,这条路径才能拼凑出一个回文串。
一开始想的是通过求到每个点到根的异或和,然后枚举一些所有为回文串的异或和。后面发现这样基本不可做,因为两个点的异或和异或起来并不为路径的异或和,必须再异或上这两个点的LCA的权值才行。这样的话就没办法枚举了。
因此只能使用点分治来进行枚举。
当点分治一个重心时,需要求出它的这些子树对它的贡献。可以分别枚举每棵子树,除被枚举到的这棵子树外,求出其它子树中每个点到重心的异或和,放到一个权值计数器中。然后枚举这棵子树的所有点,设其中一点到子树的根为
r
t
rt
rt ,则异或上
0
0
0 和 1<<i
然后累加上权值计数器记录的值并求和。则求出所有子树对重心的贡献。在枚举的过程中同时还会对被枚举的点产生贡献。贡献为重心的其它子树的点对它产生的。当一个端点为重心本身时,需要特殊处理一下。
通过点分治重心分隔子树的方法,可以使得每个点计算到所有其它点的贡献并求和。
由于枚举所有异或值是
O
(
l
o
g
n
)
O(logn)
O(logn) 的。
整体复杂度是
O
(
n
l
o
g
2
n
)
O(nlog^2n)
O(nlog2n) 。
当需要求每个点的答案,且每个点的答案都有其它点的贡献,且可以通过维护其它点的权值计数器类似的数据结构来计算得到时,可以使用点分治。
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = 2e5+7;
const int B = 20;
vector<int> adj[N];
char s[N];
bool del[N];
int c[N], siz, rt, cnt[1<<B], sz[N], f[N];
ll ans[N], tsum;
void getroot(int u, int p) {
sz[u]=1; f[u]=0;
for(int v : adj[u]) {
if(del[v]||v==p) continue;
getroot(v, u);
sz[u]+=sz[v];
f[u]=max(f[u], sz[v]);
}
f[u]=max(f[u], siz-sz[u]);
if(f[u]<f[rt]) rt=u;
}
void dfs(int u, int p, int st, int val) {
st ^= c[u];
cnt[st]+=val;
for(int v : adj[u]) {
if(del[v]||v==p) continue;
dfs(v, u, st, val);
}
}
ll cal(int u, int p, int st) {
st^=c[u];
ll sum = cnt[st];
if(st==c[rt]) tsum++;
for(int i=0; i<B; ++i) {
if((st^(1<<i))==c[rt]) tsum++;
sum += cnt[st^(1<<i)];
}
for(int v : adj[u]) {
if(del[v]||v==p) continue;
sum += cal(v, u, st);
}
ans[u]+=sum;
// if(u==7)cout << "sum " << sum << endl;
return sum;
}
void solve(int u) {
dfs(u, 0, 0, 1); //从重心到以重心为根的子树的所有点的异或和贡献。
del[u] = true;
ll sum = 0;
tsum=0;
for(int v : adj[u]) {
if(del[v]) continue;
dfs(v, 0, c[u], -1);
// sum为以u为重心的子树内的点给它的贡献
sum += cal(v, 0, 0);
dfs(v, 0, c[u], 1);
}
dfs(u, 0, 0, -1);
sum+=tsum;
ans[u]+=sum/2;
// printf("%d %I64d\n", u, sum);
for(int v : adj[u]) {
if(del[v]) continue;
f[0]=siz=sz[v];
rt=0;
getroot(v, 0);
solve(rt);
}
}
int main() {
int n;
scanf("%d", &n);
for(int i=1; i<n; ++i) {
int u, v;
scanf("%d%d", &u, &v);
adj[u].push_back(v);
adj[v].push_back(u);
}
scanf("%s", s);
for(int i=0; i<n; ++i) {
int v = s[i]-'a';
c[i+1]=(1<<v);
}
f[0]=siz=n;
rt=0;
getroot(1, 0);
solve(rt);
for(int i=1; i<=n; ++i) {
printf("%I64d%c", ans[i]+1, i==n?'\n':' ');
}
return 0;
}