Gym 101138J Valentina and the Gift Tree(树链剖分)
树链剖分,线段树
第一次学树链剖分。。就搞了这么难一题。。各种代码看了好几天才明白。。
传送门:CodeForce
传送门:HustOJ
要是想要测试数据和别人的代码,可以去这个OJ(不要干坏事哦~)
传送门:Hackerearth
题意
建议读原题。
一棵树,100000节点,树上每个节点有权值,整数。有500000个查询,每次查询给树上两点。求树上两点形成的路径上(包括两端点),最大连续子区间权值和。
关于连续子区间权值和,比如第一个样例的第一个查询,路径权值是2 -1 -2 5。连续子区间权值和是5。
思路
10w节点,50w查询。。肯定要树链剖分。。关于最大连续子区间权值和,用线段树维护。
先说说线段树维护最大子区间权值和。
维护四个信息,最大前缀,最大后缀,最大子区间,区间和。
区间合并时,大区间最大前缀=max(左子最大前缀,左子区间和+右子最大前缀)。后缀同理。
大区间最大子区间=max(左子最大子区间,右子最大子区间,左子最大后缀+右子最大前缀)
struct STREE
{
//维护最大前缀,最大后缀,最大子区间,区间和
LL MPrefix, MPostfix, Sum, MaxValue;
STREE() { MPostfix=MPrefix=Sum=0;MaxValue=-loo; }
STREE(LL l, LL r, LL s, LL m) { MPrefix=l;MPostfix=r; Sum=s;MaxValue=m; }
STREE operator + (const STREE &a)const
{
STREE New;
New.MPrefix=max(MPrefix, Sum+a.MPrefix);
New.MPostfix=max(a.MPostfix, a.Sum+MPostfix);
New.Sum=Sum+a.Sum;
New.MaxValue=max(a.MaxValue, max(MaxValue, MPostfix+a.MPrefix));
return New;
}
}Stree[MAXN<<2];
然后是树链剖分。树链剖分其实就是将一棵树节点重新编号,存到数据结构(比如线段树)里面。
为什么要重新编号呢?因为线段树可以区间更新、区间查询,而如果不重新给树编号,那么我们无法最大程度的利用区间的特性。
剖分后,有重链,轻链的说法。重链就是由大部分节点构成的链。
我们通过重新编号,使得重链在线段树里面连续保存,这样对树更新时,占了大部分节点的重链就可以区间更新,而其他轻链进行单点更新,加快速度。
重新编号的方法就是DFS,有条件的DFS。
关于树链剖分的讲解:
我的理解
第一次DFS时,获取的信息有深度,父节点,子树节点个数(SonAmount),重儿子编号。
void dfs1(int n)//调用之前初始化Depth[1]=1
{
SonAmount[n]=1;
for(int i=0;i<G[n].size();i++)
{
int to=G[n][i];
if(Depth[to]) continue;
Depth[to]=Depth[n]+1;
Father[to]=n;
dfs1(to);
SonAmount[n]+=SonAmount[to];
if(SonAmount[to]>SonAmount[Hson[n]]) Hson[n]=to;
//如果to的树节点数目比目前n的重儿子多 那么to是n的重儿子
}
return;
}
第二次DFS时获取的信息有DFS序号,新序号下的点权(边权),重链链首。注意到每个节点时,先DFS重儿子,这样如果有一条由许多重儿子构成的重链,那么他们的dfs序号一定是连续的,重链头也就是depth最小的那个节点。保证了线段树区间更新。
void dfs2(int n, int prev)
{
Dfsnum[n]=++dfscount;//dfs序号 建线段树用
TreeValue[dfscount]=Val[n];//为线段树保存点权
TopOfHeavyChain[n]=prev;//重链头
if(!Hson[n]) return;
dfs2(Hson[n], prev);
for(int i=0;i<G[n].size();i++)//dfs轻儿子
{
int to=G[n][i];
if(to==Hson[n]||to==Father[n]) continue;
dfs2(to, to);
}
}
最后查询时,查询两个节点ab,如果不在同一条重链上,那么往上跳,跳的方法就是不断查询a到fa=TopOfHeavyChain[a],以及b和fb=TopOfHeavyChain[b],a=father[fa],b=father[fb],到一条重链后最后查询一次这条重链。就结束了。详见代码,说不太清。
代码
#include<cstdio>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#include<string>
#include<cstring>
#include<vector>
#include<cmath>
#include<queue>
#include<stack>
#define _ ios_base::sync_with_stdio(0);cin.tie(0);
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
using namespace std;
const int MAXN=100100;
const int oo=0x3f3f3f3f;
typedef long long LL;
const LL loo=4223372036854775807ll;
vector<int> G[MAXN];
int Val[MAXN], Hson[MAXN], SonAmount[MAXN], Father[MAXN], Depth[MAXN];
int Dfsnum[MAXN], TreeValue[MAXN], TopOfHeavyChain[MAXN];
int dfscount;
void AddEdge(int from, int to)
{
G[from].push_back(to);
G[to].push_back(from);
}
void dfs1(int n)
{
SonAmount[n]=1;
for(int i=0;i<G[n].size();i++)
{
int to=G[n][i];
if(Depth[to]) continue;
Depth[to]=Depth[n]+1;
Father[to]=n;
dfs1(to);
SonAmount[n]+=SonAmount[to];
if(SonAmount[to]>SonAmount[Hson[n]]) Hson[n]=to;
//如果to的树节点数目比目前n的重儿子多 那么to是n的重儿子
}
return;
}
void dfs2(int n, int prev)
{
Dfsnum[n]=++dfscount;//dfs序号 建线段树用
TreeValue[dfscount]=Val[n];//为线段树保存点权
TopOfHeavyChain[n]=prev;//重链头
if(!Hson[n]) return;
dfs2(Hson[n], prev);
for(int i=0;i<G[n].size();i++)
{
int to=G[n][i];
if(to==Hson[n]||to==Father[n]) continue;
dfs2(to, to);
}
}
struct STREE
{
LL MPrefix, MPostfix, Sum, MaxValue;
//STREE(LL x) { MPostfix=MPrefix=Sum=MaxValue=x; }
STREE() { MPostfix=MPrefix=Sum=0;MaxValue=-loo; }
STREE(LL l, LL r, LL s, LL m) { MPrefix=l;MPostfix=r; Sum=s;MaxValue=m; }
STREE operator + (const STREE &a)const
{
STREE New;
New.MPrefix=max(MPrefix, Sum+a.MPrefix);
New.MPostfix=max(a.MPostfix, a.Sum+MPostfix);
New.Sum=Sum+a.Sum;
New.MaxValue=max(a.MaxValue, max(MaxValue, MPostfix+a.MPrefix));
return New;
}
}Stree[MAXN<<2];
void build(int l, int r, int rt)
{
if(l==r)
{
Stree[rt].MaxValue=Stree[rt].MPostfix=Stree[rt].MPrefix=Stree[rt].Sum=TreeValue[l];
return;
}
int m=(l+r)>>1;
build(lson);
build(rson);
Stree[rt]=Stree[rt<<1]+Stree[rt<<1|1];
return;
}
STREE query(int L, int R, int l, int r, int rt)
{
if(L<=l&&r<=R) return Stree[rt];
int m=(l+r)>>1;
if(m< L) return query(L, R, rson);
if(m>=R) return query(L, R, lson);
return (query(L, R, lson)+query(L, R, rson));
}
LL solve(int a, int b,int n)
{
STREE lc, rc;
int fa=TopOfHeavyChain[a], fb=TopOfHeavyChain[b];
while(fa!=fb)
{
if(Depth[fa]>Depth[fb])
{
lc=query(Dfsnum[fa], Dfsnum[a], 1, n, 1)+lc;
a=Father[fa];
fa=TopOfHeavyChain[a];
}
else
{
rc=query(Dfsnum[fb], Dfsnum[b], 1, n, 1)+rc;
b=Father[fb];
fb=TopOfHeavyChain[b];
}
}
if(Depth[a]>Depth[b])
{
lc=query(Dfsnum[b], Dfsnum[a], 1, n, 1)+lc;
}
else
{
rc=query(Dfsnum[a], Dfsnum[b], 1, n, 1)+rc;
}
swap(lc.MPostfix, lc.MPrefix);
return ((lc+rc).MaxValue);
}
int main()
{
_
int n;cin>>n;
for(int i=1;i<n;i++)
{
int ta, tb;
cin>>ta>>tb;
AddEdge(ta, tb);
}
for(int i=1;i<=n;i++) cin>>Val[i];
Depth[1]=1;dfs1(1);dfs2(1, 1);
build(1, n, 1);
int m;
cin>>m;
while(m--)
{
int ta, tb;
cin>>ta>>tb;
cout<<solve(ta, tb, n)<<endl;
}
//system("pause");
return 0;
}