题目
简化题意,
n(n<=1e6)个点,1为根的有根树
点i对应一个权值wi,wi在[0,n]之间
对于每个i,问在其子树中有多少种深度d,
满足该深度下存在两个点a、b,使得(x是给出的值,也在[0,n]之间)
对于这个i,计其答案为ans[i],
要求输出对998244353取模的值
思路来源
官方题解
题解
如果n=1e5,可以用unorder_map的双log的启发式合并做,这也是C.虫洞(中等)的做法
启发式合并如果每一层有x个点,就动态开长为x的数组,也可以压到复杂度O(nlogn),
但是常数巨大,所以卡不过D,怎么写怎么T,最后无奈来补O(n)做法
bfs的过程中,一边bfs一边尺取,
tong[x]维护值为x时当前出现的最大的bfs序下标,
能更新当且仅当v所需要的tong[v^x]和v在同一层
区间包含,增序扫r,维护当前最大左端点maxl,让maxl只增不减,即可避免区间嵌套
考虑要对lca(v1,gv1)单点+1,lca(v2,gv2)单点+1,对lca(lca(v1,gv1),lca(v2,gv2))单点-1,
tarjan离线,分两次离线,第一次dfs1把+1的LCA离线出来,第二次dfs2再去离线-1的LCA
最后用dfs3做一遍树上前缀和,顺便统计答案
总复杂度是O(n)的,但是扫了将近10次序列,所以常数还是很大
vector建图怎么改怎么T,链式前向星跑到了600ms上下
代码
#include<bits/stdc++.h>
using namespace std;
namespace fastIO
{
static char buf[100000],*h=buf,*dp=buf;//缓存开大可减少读入时间,看题目给的空间
#define gc h==dp&&(dp=(h=buf)+fread(buf,1,100000,stdin),h==dp)?EOF:*h++//不能用fread则换成getchar
template<typename T>
inline void read(T&x)
{
int f = 1;x = 0;
register char c(gc);
while(c>'9'||c<'0'){
if(c == '-') f = -1;
c=gc;
}
while(c<='9'&&c>='0')x=(x<<1)+(x<<3)+(c^48),c=gc;
x *= f;
}
template<typename T>
void output(T x)
{
if(x<0){putchar('-');x=~(x-1);}
static int s[20],top=0;
while(x){s[++top]=x%10;x/=10;}
if(!top)s[++top]=0;
while(top)putchar(s[top--]+'0');
}
}
using namespace fastIO;
#define pb push_back
#define fi first
#define se second
typedef long long ll;
typedef pair<int,int> P;
const int N=1e6+10,mod=998244353;
struct node{
int lca,pre;
}add[N];
int del[N],c2;
int n,x,v,w[N],c;
int q[N],t,d[N],tong[N];
int ans[N],res;
int head[N],nex[N],to[N],cnt;
int par[N];
bool vis[N],fir[N];
int find(int x){
return par[x]==x?x:par[x]=find(par[x]);
}
void add_edge(int x,int y){
nex[++cnt]=head[x];
to[cnt]=y;
head[x]=cnt;
}
vector<P>Q[N];
void init(int n){
for(int i=1;i<=n;++i){
par[i]=i;
Q[i].clear();
vis[i]=0;
}
}
void unite(int x,int y){
x=find(x),y=find(y);
if(x==y)return;
par[y]=x;
}
void dfs1(int u){
vis[u]=1;
for(int i=head[u];i;i=nex[i]){
int v=to[i];
dfs1(v);
unite(u,v);
}
for(auto &x:Q[u]){
int v=x.fi,id=x.se;
if(!vis[v])continue;
add[id].lca=find(v);
}
}
void dfs2(int u){
vis[u]=1;
for(int i=head[u];i;i=nex[i]){
int v=to[i];
dfs2(v);
unite(u,v);
}
for(auto &x:Q[u]){
int v=x.fi,id=x.se;
if(!vis[v])continue;
del[id]=find(v);
}
}
void dfs3(int u){
for(int i=head[u];i;i=nex[i]){
int v=to[i];
dfs3(v);
ans[u]+=ans[v];
}
res=(res+(u^(n-ans[u])))%mod;
}
void bfs(int z){
q[++t]=z;
d[z]=1;
int mxl,now,a,v,oth,b;
for(int s=1;s<=t;++s){
a=q[s],v=w[a],oth=v^x;
if(!fir[d[a]]){
mxl=s-1;
now=-1;
fir[d[a]]=1;
}
if(tong[oth]>mxl){
b=q[tong[oth]];
add[++c]={-1,now};
Q[a].pb(P(b,c));
Q[b].pb(P(a,c));
mxl=tong[oth];
now=c;
}
tong[v]=s;
for(int i=head[a];i;i=nex[i]){
int v=to[i];
if(d[v])continue;
q[++t]=v;
d[v]=d[a]+1;
}
}
}
int main(){
read(n),read(x);
for(int i=2;i<=n;++i){
read(v);
add_edge(v,i);
}
for(int i=1;i<=n;++i){
read(w[i]);
par[i]=i;
}
bfs(1);
dfs1(1);
init(n);
for(int i=1;i<=c;++i){
int lc=add[i].lca,pr=add[i].pre;
ans[lc]++;
if(pr!=-1){
int rc=add[pr].lca;
++c2;
Q[rc].pb(P(lc,c2));
Q[lc].pb(P(rc,c2));
}
}
dfs2(1);
for(int i=1;i<=c2;++i){
ans[del[i]]--;
}
dfs3(1);
printf("%d\n",res);
return 0;
}
/*
5 0
3 1 5 1
0 0 0 0 0
*/