题意:
有两颗树节点数都为 n n n的树 A A A与 B B B,每个点都有一个权值,有一个长度为 k k k的序列,问这个序列中恰好删除一个元素,并且其他元素在树 A A A中的 l c a lca lca的权值大于在 B B B中的 l c a lca lca的权值的删除方法有多少种。
原题很拗口,以至于一开始方向错了,通俗解释一下,加入序列是 [ 1 , 3 , 6 , 7 ] [1,3,6,7] [1,3,6,7],假设删去 3 3 3,并且有 v l c a A ( 1 , 6 , 7 ) > v l c a B ( 1 , 6 , 7 ) v_{lca_{A}(1,6,7)}>v_{lca_{B}(1,6,7)} vlcaA(1,6,7)>vlcaB(1,6,7),此时删除3就是一种方法。
方法:
本题关键在如何快速求得多个节点的 l c a lca lca,记录两种方法:
法一:
多个点的 l c a lca lca就是这些点中 d f s dfs dfs序最小和最大的那两个点的 l c a lca lca
法二(仅适用于此):
求一个序列的去掉一个点的 l c a lca lca,即这个元素前面的元素的 l c a lca lca与后面的元素的 l c a lca lca的 l c a lca lca, l c a lca lca可以叠加,不可以分离,因此可以求得这个序列的前缀 l c a lca lca与后缀 l c a lca lca,然后 i i i位置的 l c a lca lca就是 l c a ( p r e [ i − 1 ] , s u f [ i + 1 ] ) lca(pre[i-1],suf[i+1]) lca(pre[i−1],suf[i+1]),当然,这种方法也可以求得去掉一个连续段的 l c a lca lca。
d f s dfs dfs序 l c a lca lca代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
int read()
{
int ret=0,base=1;
char ch=getchar();
while(!isdigit(ch))
{
if(ch=='-') base=-1;
ch=getchar();
}
while(isdigit(ch))
{
ret=(ret<<3)+(ret<<1)+ch-48;
ch=getchar();
}
return ret*base;
}
vector<int>v1,v2;
vector<int>go1[100005],go2[100005];
int ans,depth1[100005],depth2[100005],id1[100005],id2[100005],id_,id__;;
int n,k,x[100005],lca[100005],tree1[100005],tree2[100005],f1[100005][21],f2[100005][21];
void dfs1(int u)
{
id1[u]=++id_;
depth1[u]=depth1[f1[u][0]]+1;
for(int i=1;depth1[u]-(1<<i)>=1;i++) f1[u][i]=f1[f1[u][i-1]][i-1];
for(auto i:go1[u])
{
if(i!=f1[u][0]) dfs1(i);
}
}
void dfs2(int u)
{
id2[u]=++id__;
depth2[u]=depth2[f2[u][0]]+1;;
for(int i=1;depth2[u]-(1<<i)>=1;i++) f2[u][i]=f2[f2[u][i-1]][i-1];
for(auto i:go2[u])
{
if(i!=f2[u][0]) dfs2(i);
}
}
int lca_1(int x,int y)
{
if(depth1[x]<depth1[y]) swap(x,y);
for(int i=17;i>=0;i--)
{
if(depth1[x]-(1<<i)>=depth1[y]) x=f1[x][i];
}
if(x==y) return x;
for(int i=17;i>=0;i--)
{
if(f1[x][i]!=f1[y][i])
{
x=f1[x][i];
y=f1[y][i];
}
}
return f1[x][0];
}
int lca_2(int x,int y)
{
if(depth2[x]<depth2[y]) swap(x,y);
for(int i=17;i>=0;i--)
{
if(depth2[x]-(1<<i)>=depth2[y]) x=f2[x][i];
}
if(x==y) return x;
for(int i=17;i>=0;i--)
{
if(f2[x][i]!=f2[y][i])
{
x=f2[x][i];
y=f2[y][i];
}
}
return f2[x][0];
}
bool check(int t)
{
int ret1,ret2,min1,max1;
min1=t==v1[0]?v1[1]:v1[0];
max1=t==v1[(int)v1.size()-1]?v1[(int)v1.size()-2]:v1[(int)v1.size()-1];
ret1=lca_1(min1,max1);
min1=t==v2[0]?v2[1]:v2[0];
max1=t==v2[(int)v2.size()-1]?v2[(int)v2.size()-2]:v2[(int)v2.size()-1];
ret2=lca_2(min1,max1);
return tree1[ret1]>tree2[ret2];
}
int main()
{
n=read();k=read();
for(int i=1;i<=k;i++)
{
x[i]=read();
v1.push_back(x[i]);
v2.push_back(x[i]);
}
for(int i=1;i<=n;i++) tree1[i]=read();
for(int i=2;i<=n;i++)
{
f1[i][0]=read();
go1[f1[i][0]].push_back(i);
go1[i].push_back(f1[i][0]);
}
for(int i=1;i<=n;i++) tree2[i]=read();
for(int i=2;i<=n;i++)
{
f2[i][0]=read();
go2[f2[i][0]].push_back(i);
go2[i].push_back(f2[i][0]);
}
dfs1(1);dfs2(1);
sort(v1.begin(),v1.end(),[](int temp1,int temp2){return id1[temp1]<id1[temp2];});
sort(v2.begin(),v2.end(),[](int temp1,int temp2){return id2[temp1]<id2[temp2];});
for(int i=1;i<=k;i++)
{
if(check(x[i])) ans++;
}
cout<<ans;
return 0;
}