题目来源:https://www.luogu.org/problem/P1505
★这题debug了一下午,最后终于AC了~
思路:
这题就是典型的 树链剖分+线段树 的题目,用 线段树维护 区间和、区间最大值以及区间最小值
这题比较坑的就是没有给数据范围,不过我帮你们试过水了,n不超过1e5,这不是最坑的。
很明显这些桥就是 有n个节点的树的 边,这些桥的愉悦度就是边权。有没有办法用点权记录他们呢? 不难发现,每一个非根节点 都有且仅有一个 父节点 ,我们就用每一条边的 子节点 来记录那些边权就好了。
这题 处理正负翻转 的方式是 最大的值A和最小的值B翻转后 最大的值肯定是-B 最小的值肯定是-A 用一个lazy_tag 标记这个区间的更小的区间是否翻转即可,0表示不反转 1表示反转。
最后就是注意一下 update_interval2(1,1,cnt,id[a]+1,id[b]);
这个地方要 id[ a ] +1 的原因:
举个例子,现在要求5 到 8 的区间和 ,我们模拟一遍不加1的过程
先计算3-5 重链的边权和(其实计算了5-3和3-2两条边) 然后5变成了2 2和8的链头都是1 直接跳出while循环。
接着 计算2-6-7-8 重链的边权和 (其实计算了1-2 2-6 6-7 7-8的边权和) 发现多加了一条1-2的边 ,所以我们要+1
如下图(红色为重链)
代码:
#include<iostream>
#include<algorithm>
#include<cmath>
#include<cstdio>
#include<cstring>
#include<vector>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<string>
#define ls k<<1
#define rs k<<1|1
using namespace std;
const int maxn=1e5+5;
const int sz=1<<10;
const int mod=1e9+7;
const int inf=2e9;
const double eps=1e-8;
const double pi=acos(-1);
typedef long long LL;
int n,m,cnt;
int head[maxn],nex[maxn<<1],to[maxn<<1],w[maxn<<1];
int dep[maxn],siz[maxn],son[maxn],fa[maxn],ww[maxn];
int top[maxn],id[maxn],wt[maxn];
int sum[maxn<<2],mini[maxn<<2],maxx[maxn<<2],lazy[maxn<<2];
template<class t>
inline void read(t &x) //快读
{
char c; x=1;
while((c=getchar())<'0'||c>'9') if(c=='-') x=-1;
t res=c-'0';
while((c=getchar())>='0'&&c<='9') res=res*10+c-'0';
x*=res;
}
inline void add_edge(int x,int y,int v) //链式前向星
{
to[++cnt]=y;
w[cnt]=v;
nex[cnt]=head[x];
head[x]=cnt;
}
void dfs1(int x,int f,int d)
{
dep[x]=d;
fa[x]=f;
siz[x]=1;
for(int i=head[x];i;i=nex[i]){
int y=to[i];
if(y==f) continue;
ww[y]=w[i];
dfs1(y,x,d+1);
siz[x]+=siz[y];
if(siz[y]>siz[son[x]]) son[x]=y;
}
}
void dfs2(int x,int topf)
{
top[x]=topf;
wt[++cnt]=ww[x];
id[x]=cnt;
if(son[x]==0) return ;
dfs2(son[x],topf);
for(int i=head[x];i;i=nex[i]){
int y=to[i];
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
inline void push_up(int k)
{
maxx[k]=max(maxx[ls],maxx[rs]);
mini[k]=min(mini[ls],mini[rs]);
sum[k]=sum[ls]+sum[rs];
}
inline void push_down(int k,int l,int r,int mid)
{
if(lazy[k]){
mini[ls]=-mini[ls]; mini[rs]=-mini[rs];
maxx[ls]=-maxx[ls]; maxx[rs]=-maxx[rs];
sum[ls]=-sum[ls];sum[rs]=-sum[rs];
swap(mini[ls],maxx[ls]); swap(mini[rs],maxx[rs]);
lazy[ls]^=1; lazy[rs]^=1; lazy[k]=0;
}
}
void build(int k,int l,int r)
{
if(l==r){mini[k]=maxx[k]=sum[k]=wt[l]; return ; }
int mid=l+r>>1;
build(ls,l,mid);
build(rs,mid+1,r);
push_up(k);
}
void update_point(int k,int l,int r,int pos,int v) //单点更新
{
if(l==r){sum[k]=mini[k]=maxx[k]=v; return ; }
int mid=l+r>>1;
push_down(k,l,r,mid);
if(mid>=pos) update_point(ls,l,mid,pos,v);
else update_point(rs,mid+1,r,pos,v);
push_up(k);
}
void update_interval2(int k,int l,int r,int x,int y) //区间更新
{
if(x<=l&&r<=y){
lazy[k]^=1; sum[k]=-sum[k];
mini[k]=-mini[k]; maxx[k]=-maxx[k];
swap(mini[k],maxx[k]);
return ;
}
int mid=l+r>>1;
push_down(k,l,r,mid);
if(mid>=x) update_interval2(ls,l,mid,x,y);
if(mid<y) update_interval2(rs,mid+1,r,x,y);
push_up(k);
}
int query_sum2(int k,int l,int r,int x,int y)
{
if(x<=l&&r<=y) return sum[k];
int mid=l+r>>1,ans=0;
push_down(k,l,r,mid);
if(mid>=x) ans+=query_sum2(ls,l,mid,x,y);
if(mid<y) ans+=query_sum2(rs,mid+1,r,x,y);
return ans;
}
int query_max2(int k,int l,int r,int x,int y)
{
if(x<=l&&r<=y) return maxx[k];
int mid=l+r>>1,ans=-inf;
push_down(k,l,r,mid);
if(mid>=x) ans=max(ans,query_max2(ls,l,mid,x,y));
if(mid<y) ans=max(ans,query_max2(rs,mid+1,r,x,y));
return ans;
}
int query_min2(int k,int l,int r,int x,int y)
{
if(x<=l&&r<=y) return mini[k];
int mid=l+r>>1;
push_down(k,l,r,mid);
int ans=inf;
if(mid>=x) ans=min(ans,query_min2(ls,l,mid,x,y));
if(mid<y) ans=min(ans,query_min2(rs,mid+1,r,x,y));
return ans;
}
void update_interval1(int a,int b)
{
while(top[a]!=top[b]){
if(dep[top[a]]<dep[top[b]]) swap(a,b);
update_interval2(1,1,cnt,id[top[a]],id[a]);
a=fa[top[a]];
}
if(dep[a]>dep[b]) swap(a,b);
update_interval2(1,1,cnt,id[a]+1,id[b]);
}
int query_sum1(int a,int b)
{
int ans=0;
while(top[a]!=top[b]){
if(dep[top[a]]<dep[top[b]]) swap(a,b);
ans+=query_sum2(1,1,cnt,id[top[a]],id[a]);
a=fa[top[a]];
}
if(dep[a]>dep[b]) swap(a,b);
return ans+query_sum2(1,1,cnt,id[a]+1,id[b]);
}
int query_max1(int a,int b)
{
int ans=-inf;
while(top[a]!=top[b]){
if(dep[top[a]]<dep[top[b]]) swap(a,b);
ans=max(ans,query_max2(1,1,cnt,id[top[a]],id[a]));
a=fa[top[a]];
}
if(dep[a]>dep[b]) swap(a,b);
return max(ans,query_max2(1,1,cnt,id[a]+1,id[b]));
}
int query_min1(int a,int b)
{
int ans=inf;
while(top[a]!=top[b]){
if(dep[top[a]]<dep[top[b]]) swap(a,b);
ans=min(ans,query_min2(1,1,cnt,id[top[a]],id[a]));
a=fa[top[a]];
}
if(dep[a]>dep[b]) swap(a,b);
return min(ans,query_min2(1,1,cnt,id[a]+1,id[b]));
}
int main()
{
read(n); cnt=0;
for(int i=1;i<n;i++){
int a,b,c;
read(a); read(b); read(c);
add_edge(a+1,b+1,c); add_edge(b+1,a+1,c);
}
cnt=0;
dfs1(1,0,1); dfs2(1,1); build(1,1,cnt);
read(m);
while(m--){
char op[10]; int a,b;
scanf("%s %d %d",op,&a,&b);
if(op[0]=='C') update_point(1,1,cnt,id[a+1],b);
else if(op[0]=='N') update_interval1(a+1,b+1);
else if(op[0]=='S') printf("%d\n",query_sum1(a+1,b+1));
else if(op[1]=='A') printf("%d\n",query_max1(a+1,b+1));
else if(op[1]=='I') printf("%d\n",query_min1(a+1,b+1));
// for(int i=1;i<=10;i++) cout<<sum[i]<<' ';
// cout<<endl;
}
return 0;
}
//0 2
//0 1 2 2(2)
//0 0 1 1(1)