题解:
这种题目一看就和
d
f
s
dfs
dfs序有关。考虑对于一条链,哪些链能够覆盖它。分两种情况,为了方便,链的两端点分别为
x
,
y
x,y
x,y,且
i
n
[
x
]
<
=
i
n
[
y
]
in[x]<=in[y]
in[x]<=in[y]。
1、
i
n
[
x
]
<
=
i
n
[
y
]
in[x]<=in[y]
in[x]<=in[y]且
i
n
[
y
]
<
=
o
u
t
[
x
]
in[y]<=out[x]
in[y]<=out[x],这种就是一条直链,那么能够覆盖它的就是一个端点在
y
y
y子树内,另外一个端点在
x
x
x外面以及除
y
y
y的子树内的链。
2、另外一种就是倒
V
V
V形的链,此时能覆盖它的是一端点在
x
x
x子树内,另外一个端点在
y
y
y子树内的链。
然后发现每条链可以看做一个点
(
i
n
[
x
]
,
i
n
[
y
]
)
(in[x],in[y])
(in[x],in[y]),每个盘子的贡献区间是一个或两个矩阵,那么就可以用扫描线的方法做,需要支持以下操作:1、
[
L
,
R
]
[L,R]
[L,R]每个位置插入一个数。2、查询
x
x
x位置上的第
k
k
k小值。3、
[
L
,
R
]
[L,R]
[L,R]每个位置删除一个数。这就是整体二分的入门题,和BZOJ3110比较像,直接做。
数据结构的功能其实是区间修改和单点查询,用树状数组就好了,不要像我这么傻。
不知道为什么代码这么长……
代码:
#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define pa pair<int,int>
const int Maxn=40010;
const int inf=2147483647;
int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
return x*f;
}
int n,P,Q,dfn=0,in[Maxn],out[Maxn],fa[Maxn][16],dep[Maxn],val[Maxn];
struct Edge{int y,next;}e[Maxn<<1];
int last[Maxn],len=0;
void ins(int x,int y)
{
int t=++len;
e[t].y=y;e[t].next=last[x];last[x]=t;
}
struct Opt{int type,x,y,z,t;}O[Maxn*5],q1[Maxn*5],q2[Maxn*5];int lo=0;
bool cmp(Opt a,Opt b)
{
if(a.t!=b.t)return a.t<b.t;
return a.type>b.type;
}
/*
1 [x,y] 插入z
-1 [x,y] 删除z
0 询问x位置 第z小 id为y
*/
void dfs(int x,int ff)
{
in[x]=++dfn;dep[x]=dep[ff]+1;fa[x][0]=ff;
for(int i=1;(1<<i)<=dep[x];i++)fa[x][i]=fa[fa[x][i-1]][i-1];
for(int i=last[x];i;i=e[i].next)
{
int y=e[i].y;
if(y==ff)continue;
dfs(y,x);
}
out[x]=dfn;
}
int jump(int x,int y)
{
for(int i=15;i>=0;i--)
if((1<<i)<=dep[x]&&dep[fa[x][i]]>dep[y])x=fa[x][i];
return x;
}
/*Segment*/
struct Seg{int l,r,lc,rc,c,tag;}tr[Maxn<<1];
int tot=0;
void Add(int x,int v)
{
tr[x].c+=(tr[x].r-tr[x].l+1)*v;
tr[x].tag+=v;
}
void down(int x)
{
if(tr[x].tag)
{
Add(tr[x].lc,tr[x].tag),Add(tr[x].rc,tr[x].tag);
tr[x].tag=0;
}
}
void build(int l,int r)
{
int x=++tot;
tr[x].l=l;tr[x].r=r;tr[x].c=tr[x].tag=0;
if(l==r)return;
int mid=l+r>>1;
tr[x].lc=tot+1,build(l,mid);
tr[x].rc=tot+1,build(mid+1,r);
}
void add(int x,int l,int r,int v)
{
if(tr[x].l==l&&tr[x].r==r){Add(x,v);return;}
int mid=tr[x].l+tr[x].r>>1,lc=tr[x].lc,rc=tr[x].rc;
down(x);
if(r<=mid)add(lc,l,r,v);
else if(l>mid)add(rc,l,r,v);
else add(lc,l,mid,v),add(rc,mid+1,r,v);
tr[x].c=tr[lc].c+tr[rc].c;
}
int query(int x,int p)
{
if(tr[x].l==tr[x].r)return tr[x].c;
int mid=tr[x].l+tr[x].r>>1,lc=tr[x].lc,rc=tr[x].rc;
down(x);
if(p<=mid)return query(lc,p);
return query(rc,p);
}
int ans[Maxn];
void solve(int ql,int qr,int l,int r)
{
if(l==r)
{
for(int i=ql;i<=qr;i++)
if(O[i].type==0)ans[O[i].y]=val[l];
return;
}
int mid=l+r>>1,l1=0,l2=0;
for(int i=ql;i<=qr;i++)
{
if(O[i].type==0)
{
int t=query(1,O[i].x);
if(t>=O[i].z)q1[++l1]=O[i];
else q2[++l2]=O[i],q2[l2].z-=t;
}
else
{
if(O[i].z<=val[mid])add(1,O[i].x,O[i].y,O[i].type),q1[++l1]=O[i];
else q2[++l2]=O[i];
}
}
int tmp=0;
for(int i=1;i<=l1;i++)
{
tmp++;
O[ql+tmp-1]=q1[i];
if(q1[i].type!=0)add(1,q1[i].x,q1[i].y,-q1[i].type);
}
for(int i=1;i<=l2;i++)tmp++,O[ql+tmp-1]=q2[i];
solve(ql,ql+l1-1,l,mid),solve(ql+l1,qr,mid+1,r);
}
int main()
{
n=read(),P=read(),Q=read();
for(int i=1;i<n;i++)
{
int x=read(),y=read();
ins(x,y),ins(y,x);
}
dep[0]=-1;dfs(1,0);
for(int i=1;i<=P;i++)
{
int x=read(),y=read(),z=read();val[i]=z;
if(in[x]>in[y])swap(x,y);
int t,p=jump(y,x);
if(in[x]<=in[y]&&in[y]<=out[x])
{
if(in[p]>1)
{
t=++lo;O[t].type=1,O[t].x=in[y],O[t].y=out[y],O[t].z=z,O[t].t=1;
t=++lo;O[t].type=-1,O[t].x=in[y],O[t].y=out[y],O[t].z=z,O[t].t=in[p]-1;
}
if(out[p]<n)
{
t=++lo;O[t].type=1,O[t].x=out[p]+1,O[t].y=n,O[t].z=z,O[t].t=in[y];
t=++lo;O[t].type=-1,O[t].x=out[p]+1,O[t].y=n,O[t].z=z,O[t].t=out[y];
}
}
else
{
t=++lo;O[t].type=1,O[t].x=in[y],O[t].y=out[y],O[t].z=z,O[t].t=in[x];
t=++lo;O[t].type=-1,O[t].x=in[y],O[t].y=out[y],O[t].z=z,O[t].t=out[x];
}
}
sort(val+1,val+1+P);
for(int i=1;i<=Q;i++)
{
int x=read(),y=read(),k=read();
if(in[x]>in[y])swap(x,y);
int t=++lo;O[t].type=0,O[t].x=in[y],O[t].y=i,O[t].z=k,O[t].t=in[x];
}
sort(O+1,O+1+lo,cmp);
build(1,n);
solve(1,lo,1,P);
for(int i=1;i<=Q;i++)printf("%d\n",ans[i]);
}