做了4个小时,思考一小时,打代码20分钟,调试2个半小时
没看题解,自己意淫的做法,习惯数组版的线段树,所以没办法像结构体一样返回多个参数,故用引用返回
需要的知识:树剖,最近公共祖先,最大子段和
变量:mx:最大子段和,lg:区间左起的最大和,rg:区间右起的最大和,sum:区间和,tag:区间修改标记
一.树剖,基本操作
void dfs1(int x,int pre){
fa[x]=pre;
dep[x]=dep[pre]+1;
sz[x]=1;
int mxson=-1;
for(int i=head[x];i;i=nxt[i]){
int y=ver[i];
if(y==pre)continue;
dfs1(y,x);
sz[x]+=sz[y];
if(sz[y]>mxson)mxson=sz[y],son[x]=y;
}
}
//第二个dfs
void dfs2(int x,int topf){
top[x]=topf;
id[x]=++cnt;
w[id[x]]=a[x];//把点重新编号作为线段树中的编号
if(son[x])dfs2(son[x],topf);
for(int i=head[x];i;i=nxt[i]){
int y=ver[i];
if(y==fa[x]||y==son[x])continue;
dfs2(y,y);
}
}
int lca(int x,int y){//求最近公共祖先
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
return x;
}
二.build函数和pushup
tag[]是区间修改标记,因为可以修改为0,所以tag初始化为-inf
void push_up(int k){
mx[k]=max(mx[ls],max(mx[rs],rg[ls]+lg[rs]));
lg[k]=max(lg[ls],sum[ls]+lg[rs]);
rg[k]=max(rg[rs],sum[rs]+rg[ls]);
sum[k]=sum[ls]+sum[rs];
}
void build(int l,int r,int k){
tag[k]=-inf;
if(l==r){
sum[k]=lg[k]=rg[k]=mx[k]=w[l];
return;
}
build(l,mid,ls);
build(mid+1,r,rs);
push_up(k);
}
三.下传标记
修改该区间为c,如果c>=0那这样,如果c<0那就那样,很简单
最后记得把tag[k]赋值为-inf
void push_down(int k,int m){
if(tag[k]==-inf)return;//tag可以为0,所以赋值为-inf
int c=tag[k];
tag[ls]=tag[rs]=c;
if(tag[k]>=0){
mx[ls]=sum[ls]=lg[ls]=rg[ls]=(m-m/2)*c;
mx[rs]=sum[rs]=lg[rs]=rg[rs]=(m/2)*c;
}
else{
mx[ls]=lg[ls]=rg[ls]=c;
mx[rs]=lg[rs]=rg[rs]=c;
sum[ls]=(m-m/2)*c;
sum[rs]=(m/2)*c;
}
tag[k]=-inf;
}
四.update 区间修改函数
应该很好懂
void update(int a,int b,int c,int l,int r,int k){
if(a<=l&&b>=r){
int len=r-l+1;
if(c>=0){
mx[k]=lg[k]=rg[k]=sum[k]=len*c;
}
else{
mx[k]=lg[k]=rg[k]=c;
sum[k]=len*c;
}
tag[k]=c;
return;
}
push_down(k,r-l+1);
if(a<=mid)update(a,b,c,l,mid,ls);
if(b>mid)update(a,b,c,mid+1,r,rs);
push_up(k);
}
void update(int x,int y,int c){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
update(id[top[x]],id[x],c,1,n,1);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
update(id[x],id[y],c,1,n,1);
}
五.查询树上两点间的最大子段和
我的想法很暴力
首先,求出两点x,y的最近公共祖先lc=lca(x,y)
然后,分别求出x到lc的最大字段和,以及y到lc的最大字段和
最后,合并两个最大子段和
大概思路如上。
考虑如何求x和y到lc的最大子段和呢?
考虑x到lc
1.树链的每条链编号都是连续的,从上往下递增的,所以在线段树中,我们肯定能求出每条树链的mx,lg,rg,sum(一条树链的lg在深度浅部分,rg在深度深部分,因为编号是从上往下连续递增的),如果x到lc有多段不连续区间,那么我们就要合并这每段不连续的区间,合并方法和pushup函数一样
考虑y到lc
1.和x到lc求法一样,lc处注意一下,不能包括lc。
void query(int a,int b,int l,int r,int &MX,int &LG,int &RG,int &SUM,int k){
if(a>b){
MX=LG=RG=-inf;
SUM=0;
return;
}
if(a<=l&&b>=r){
MX=mx[k];
LG=lg[k];
RG=rg[k];
SUM=sum[k];
return;
}
push_down(k,r-l+1);
if(a>mid)query(a,b,mid+1,r,MX,LG,RG,SUM,rs);
else if(b<=mid)query(a,b,l,mid,MX,LG,RG,SUM,ls);
else{
int mxl,lgl,rgl,suml;
int mxr,lgr,rgr,sumr;
query(a,b,l,mid,mxl,lgl,rgl,suml,ls);
query(a,b,mid+1,r,mxr,lgr,rgr,sumr,rs);
MX=max(mxl,max(mxr,rgl+lgr));
LG=max(lgl,suml+lgr);
RG=max(rgr,sumr+rgl);
SUM=suml+sumr;
}
}
//重点,最难的部分
void merge(int &mxl,int &lgl,int &rgl,int &suml,
int &mxr,int &lgr,int &rgr,int &sumr){
mxl=max(mxl,max(mxr,rgl+lgr));
lgl=max(lgl,suml+lgr);
rgl=max(rgr,rgl+sumr);
suml+=sumr;
}
void get(int &MX,int &LG,int &RG,int &SUM,int x,int lc,int rev){
int mxl=0,lgl=0,rgl=0,suml=0;
int mxr=0,lgr=0,rgr=0,sumr=0;
int t=0;
while(top[x]!=top[lc]){
if(t>=1)mxr=mxl,lgr=lgl,rgr=rgl,sumr=suml;
query(id[top[x]],id[x],1,n,mxl,lgl,rgl,suml,1);
x=fa[top[x]];
if(t>=1)merge(mxl,lgl,rgl,suml,mxr,lgr,rgr,sumr);
t++;
}
if(t>=1){
mxr=mxl,lgr=lgl,rgr=rgl,sumr=suml;
query(id[lc]+rev,id[x],1,n,mxl,lgl,rgl,suml,1);
merge(mxl,lgl,rgl,suml,mxr,lgr,rgr,sumr);
}
else{
query(id[lc]+rev,id[x],1,n,mxl,lgl,rgl,suml,1);
}
MX=mxl,LG=lgl,RG=rgl,SUM=suml;
}
int query(int x,int y){
int lc=lca(x,y);
int mxl,lgl,rgl,suml;
int mxr,lgr,rgr,sumr;
get(mxl,lgl,rgl,suml,x,lc,0);
get(mxr,lgr,rgr,sumr,y,lc,1);
int ans=0;
ans=max(mxl,max(mxr,lgl+lgr));
return ans;
}
完整代码
#include <bits/stdc++.h>
#define mid (l+r)/2
#define ls (k<<1)
#define rs (k<<1|1)
#define inf 0x3f3f3f3f
using namespace std;
const int N=4e5+10;
int tot=1;
int head[N],ver[N],nxt[N];
void add(int u,int v){
ver[++tot]=v,nxt[tot]=head[u],head[u]=tot;
}
int n,m;
int a[N],w[N],id[N];
int son[N],sz[N],dep[N],fa[N],top[N];
int cnt;
//树剖
//第一个dfs
void dfs1(int x,int pre){
fa[x]=pre;
dep[x]=dep[pre]+1;
sz[x]=1;
int mxson=-1;
for(int i=head[x];i;i=nxt[i]){
int y=ver[i];
if(y==pre)continue;
dfs1(y,x);
sz[x]+=sz[y];
if(sz[y]>mxson)mxson=sz[y],son[x]=y;
}
}
//第二个dfs
void dfs2(int x,int topf){
top[x]=topf;
id[x]=++cnt;
w[id[x]]=a[x];//把点重新编号作为线段树中的编号
if(son[x])dfs2(son[x],topf);
for(int i=head[x];i;i=nxt[i]){
int y=ver[i];
if(y==fa[x]||y==son[x])continue;
dfs2(y,y);
}
}
int lca(int x,int y){//求最近公共祖先
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
return x;
}
int mx[N],lg[N],rg[N],sum[N],tag[N];
void push_up(int k){
mx[k]=max(mx[ls],max(mx[rs],rg[ls]+lg[rs]));
lg[k]=max(lg[ls],sum[ls]+lg[rs]);
rg[k]=max(rg[rs],sum[rs]+rg[ls]);
sum[k]=sum[ls]+sum[rs];
}
void build(int l,int r,int k){
tag[k]=-inf;
if(l==r){
sum[k]=lg[k]=rg[k]=mx[k]=w[l];
return;
}
build(l,mid,ls);
build(mid+1,r,rs);
push_up(k);
}
void push_down(int k,int m){
if(tag[k]==-inf)return;//tag可以为0,所以赋值为-inf
int c=tag[k];
tag[ls]=tag[rs]=c;
if(tag[k]>=0){
mx[ls]=sum[ls]=lg[ls]=rg[ls]=(m-m/2)*c;
mx[rs]=sum[rs]=lg[rs]=rg[rs]=(m/2)*c;
}
else{
mx[ls]=lg[ls]=rg[ls]=c;
mx[rs]=lg[rs]=rg[rs]=c;
sum[ls]=(m-m/2)*c;
sum[rs]=(m/2)*c;
}
tag[k]=-inf;
}
void update(int a,int b,int c,int l,int r,int k){
if(a<=l&&b>=r){
int len=r-l+1;
if(c>=0){
mx[k]=lg[k]=rg[k]=sum[k]=len*c;
}
else{
mx[k]=lg[k]=rg[k]=c;
sum[k]=len*c;
}
tag[k]=c;
return;
}
push_down(k,r-l+1);
if(a<=mid)update(a,b,c,l,mid,ls);
if(b>mid)update(a,b,c,mid+1,r,rs);
push_up(k);
}
void update(int x,int y,int c){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
update(id[top[x]],id[x],c,1,n,1);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
update(id[x],id[y],c,1,n,1);
}
void query(int a,int b,int l,int r,int &MX,int &LG,int &RG,int &SUM,int k){
if(a>b){
MX=LG=RG=-inf;
SUM=0;
return;
}
if(a<=l&&b>=r){
MX=mx[k];
LG=lg[k];
RG=rg[k];
SUM=sum[k];
return;
}
push_down(k,r-l+1);
if(a>mid)query(a,b,mid+1,r,MX,LG,RG,SUM,rs);
else if(b<=mid)query(a,b,l,mid,MX,LG,RG,SUM,ls);
else{
int mxl,lgl,rgl,suml;
int mxr,lgr,rgr,sumr;
query(a,b,l,mid,mxl,lgl,rgl,suml,ls);
query(a,b,mid+1,r,mxr,lgr,rgr,sumr,rs);
MX=max(mxl,max(mxr,rgl+lgr));
LG=max(lgl,suml+lgr);
RG=max(rgr,sumr+rgl);
SUM=suml+sumr;
}
}
//重点,最难的部分
void merge(int &mxl,int &lgl,int &rgl,int &suml,
int &mxr,int &lgr,int &rgr,int &sumr){
mxl=max(mxl,max(mxr,rgl+lgr));
lgl=max(lgl,suml+lgr);
rgl=max(rgr,rgl+sumr);
suml+=sumr;
}
void get(int &MX,int &LG,int &RG,int &SUM,int x,int lc,int rev){
int mxl=0,lgl=0,rgl=0,suml=0;
int mxr=0,lgr=0,rgr=0,sumr=0;
int t=0;
while(top[x]!=top[lc]){
if(t>=1)mxr=mxl,lgr=lgl,rgr=rgl,sumr=suml;
query(id[top[x]],id[x],1,n,mxl,lgl,rgl,suml,1);
x=fa[top[x]];
if(t>=1)merge(mxl,lgl,rgl,suml,mxr,lgr,rgr,sumr);
t++;
}
if(t>=1){
mxr=mxl,lgr=lgl,rgr=rgl,sumr=suml;
query(id[lc]+rev,id[x],1,n,mxl,lgl,rgl,suml,1);
merge(mxl,lgl,rgl,suml,mxr,lgr,rgr,sumr);
}
else{
query(id[lc]+rev,id[x],1,n,mxl,lgl,rgl,suml,1);
}
MX=mxl,LG=lgl,RG=rgl,SUM=suml;
}
int query(int x,int y){
int lc=lca(x,y);
int mxl,lgl,rgl,suml;
int mxr,lgr,rgr,sumr;
get(mxl,lgl,rgl,suml,x,lc,0);
get(mxr,lgr,rgr,sumr,y,lc,1);
int ans=0;
ans=max(mxl,max(mxr,lgl+lgr));
return ans;
}
int main(){
cin>>n;
for(int i=1;i<=n;i++)cin>>a[i];
for(int i=1;i<n;i++){
int x,y;
cin>>x>>y;
add(x,y);
add(y,x);
}
dfs1(1,0);
dfs2(1,1);
build(1,n,1);
cin>>m;
while(m--){
int op,x,y,z;
cin>>op>>x>>y;
if(op==1){
int ans=query(x,y);
if(ans<0)cout<<0;
else cout<<ans;
cout<<endl;
}
else{
cin>>z;
update(x,y,z);
}
}
}
/*
11
-1 -3 7 4 2 -8 -4 7 -1 3 9
1 2
1 3
2 4
2 5
3 6
3 7
4 8
4 9
6 10
6 11
100
*/