题意:
已知一棵树,求树上一节点 x 的子树中能力比x大的且拥有最大贡献值的子节点
思路:
首先用dfs序将树形变为线性,然后将总结点数n 分块
每块有u=sqrt(n)个节点,将每个块按能力值从小到大排序
ma数组记录 从节点 i 到 块尾 最大贡献值的下标。
查询节点时需要查询其dfs序所包含的区间
区间中,如果有整块存在,则直接二分该块
剩余的区间元素需要查询 整块的前一个块和后一个块(因为sort过,块内顺序会改变,所以要遍历整个块去找寻每个元素是否是我们要查找的元素)
#include<stdio.h>
#include<string.h>
#include<vector>
#include<math.h>
#include<algorithm>
using namespace std;
#define maxn 51000
struct node
{
int a,b,c;
bool operator < (const node &x) const
{
return b < x.b;
}
} op[maxn],k[maxn];
int in[maxn],out[maxn],Time;//dfs序
int which[maxn]; //元素所属块
int u,ma[maxn]; //记录最大贡献值
vector<int>e[maxn];
void dfs(int x, int fa)
{
in[x] = ++Time; //进入的时间戳
k[Time].a=op[x].a;
k[Time].b=op[x].b;
k[Time].c=x;
for(int i = 0; i < e[x].size(); i++) //这里我用的vector存图,也可以用其他方法
{
int cnt = e[x][i];
if(cnt == fa) continue;
dfs(cnt, x);
}
out[x] = Time; //出去的时间戳
}
void build(int n) //分块,排序
{
u=(int)sqrt(n*1.0);
for(int i=1; i<=n; i++)
which[i]=(i-1)/u;
int ans=which[n];
for(int i=0; i<=ans; i++)
{
int x=i*u+1,y=min(x+u-1,Time);
int maxx=-1,book=-1;
sort(k+x,k+y+1);
for(int j=y; j>=x; j--) //预处理块内最大贡献值
{
if(k[j].a>maxx)
{
maxx=k[j].a;
book=j;
}
ma[j]=book; //记录下标
}
}
}
int query(int ovo) //查询
{
int x=which[in[ovo]],y=which[out[ovo]];
node c;
c.b = op[ovo].b;
int maxx=-1,book=-1;
int l=x*u+1,r=min(l+u-1,Time);
for(int i=l; i<=r; i++)
{
if(in[k[i].c]>in[ovo]&&out[k[i].c]<=out[ovo]) //遍历前一个快,确定元素是否在区间内
{
if(k[i].b>op[ovo].b&&k[i].a>maxx)
{
maxx=k[i].a;
book=k[i].c;
}
}
}
x+=1;
l=y*u+1,r=min(l+u-1,Time);
for(int i=l; i<=r; i++)
{
if(in[k[i].c]>in[ovo]&&out[k[i].c]<=out[ovo]) //遍历后一个块,确定元素是否在区间内
{
if(k[i].b>op[ovo].b&&k[i].a>maxx)
{
maxx=k[i].a;
book=k[i].c;
}
}
}
y-=1;
for(int i=x; i<=y; i++) //二分整块
{
int l=i*u+1,r=l+u-1;
int sum= upper_bound(k+l, k+r+1, c) -k; //找到大于该节点的值
if(sum>r) continue;
sum=ma[sum];
if(k[sum].a > maxx)
{
maxx = k[sum].a;
book = k[sum].c;
}
}
return book;
}
void init()
{
u=Time=0;
memset(k,0,sizeof(k));
memset(in,0,sizeof(in));
memset(out,0,sizeof(out));
memset(ma,0,sizeof(ma));
memset(op,0,sizeof(op));
}
int main()
{
int t;
scanf("%d",&t);
while(t--)
{
init();
int n,m,x;
scanf("%d %d",&n,&m);
op[0].a = -1, op[0].b = -1;
for(int i=0; i<=n; i++)
e[i].clear();
for(int i=1; i<n; i++)
{
scanf("%d %d %d",&x,&op[i].a,&op[i].b);
e[x].push_back(i);
}
dfs(0,0);
build(n);
for(int i=0; i<m; i++)
{
scanf("%d",&x);
int sum=query(x);
printf("%d\n",sum);
}
}
}