题目描述:
给定一颗有n个节点的树
每条边都有一个颜色和长度
现给出q个查询
每个查询先将所有颜色值为c的边的长度改为d,
再查询节点u到节点v的最短距离
每次查询独立,互不影响。
即所有的长度变化只在这一次查询中生效
输入格式:
n q
1 <= n , q <= 1e5
接下来n - 1行
每行 a b c d 表示a节点和b节点有一条长度为d并且颜色为c的边
1 <= a , b <= n
1 <= c <= n - 1
1 <= d <= 1e4
a1 , b1 , c1 , d1
…
an-1 , bn-1 , cn-1, dn-1
接下来q行
每行x , y , u , v
表示如果将所有颜色为x的边的长度修改为y
点u和v之间的距离是多少
x1 , y1 , u1 , v1
…
xq , yq , uq , vq
1 <= x <= n - 1
1 <= y <= 1e4
1 <= u , v <= n
维护树上路径,一般使用树链剖分算法,将树展开成一条链后,还要查询某区间内颜色c的个数,这需要使用可持久化线段树,还需要维护某路径的区间和,使用线段树。
这里使用可持久化线段树来完成。
这里有几个注意点:
1.需要将边权转移到点上,因为树链剖分维护的是点权,这里转移的方法是将每条边的值转移到他的父节点上。
2.在查询的时候,因为是将边权转移到点上,因此,当两点的顶端相同的时候,靠近根的结点值不被计算。
3.可持久化线段树创建时的模拟指针idx需要从1开始
4.new数值中1是没有值的,因为根节点没有父节点
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=100200;
const int M=1e9+20;
struct node {
int l,r;
int colornum;
int sumpath;
}tr[N*40];
int idx=1;
int head[N];//可持久化线段树头指针
int n,q;
int h[N],to[N*2],ne[N*2],idx2=1;//图的存储
int depth[N],f[N],sz[N],son[N],top[N],cnt;
int id[N],nw[N],nw2[N];
int w[N],w2[N];
//以上是树链剖分的变量
//w表示长度,w2表示color
int w3[N*2],w4[N*2];//边权
void dfs( int rt,int fa){//将边权转化为点权
for( int i=h[rt];i!=-1;i=ne[i]){
int j=to[i];
if(j==fa) continue;
w2[j]=w3[i];
w[j]=w4[i];
dfs(j,rt);
}
}
void add( int a,int b,int c,int d){
w4[idx2]=d,w3[idx2]=c,to[idx2]=b,ne[idx2]=h[a],h[a]=idx2++;
}
void pushup( int rt){
tr[rt].sumpath=tr[tr[rt].l].sumpath+tr[tr[rt].r].sumpath;
tr[rt].colornum=tr[tr[rt].l].colornum+tr[tr[rt].r].colornum;
}
int create( int l,int r){
int p=idx++;
if(l==r){
tr[p].sumpath=nw[l];
return p;
}
int mid=(l+r)/2;
tr[p].l=create(l,mid);
tr[p].r=create(mid+1,r);
tr[p].sumpath=tr[tr[p].l].sumpath+tr[tr[p].r].sumpath;
return p;
}
int insert(int o,int l,int r,int val,int val_path){
int rt=idx++;
tr[rt]=tr[o];
if(l==r){
tr[rt].colornum+=1;
tr[rt].sumpath+=val_path;
return rt;
}
int mid=(l+r)/2;
if(val<=mid){
tr[rt].l=insert(tr[rt].l,l,mid,val,val_path);
}
else tr[rt].r=insert(tr[rt].r,mid+1,r,val,val_path);
pushup(rt);
return rt;
}
int q_color( int rl,int rr,int l,int r,int L,int R){
if(l>R||r<L) return 0;
if(l==r) return tr[rr].colornum-tr[rl].colornum;
if(l>=L&&r<=R) return tr[rr].colornum-tr[rl].colornum;
int mid=(l+r)/2;
int res=0;
if(mid>=L) res+=q_color(tr[rl].l,tr[rr].l,l,mid,L,R);
if(R>mid) res+=q_color(tr[rl].r,tr[rr].r,mid+1,r,L,R);
return res;
}
int q_path( int rl,int rr,int l,int r,int L,int R){
if(l>R||r<L) return 0;
if(l==r) return tr[rr].sumpath-tr[rl].sumpath;
if(l>=L&&r<=R) return tr[rr].sumpath-tr[rl].sumpath;
int mid=(l+r)/2;
int res=0;
if(mid>=L) res+=q_path(tr[rl].l,tr[rr].l,l,mid,L,R);
if(R>mid) res+=q_path(tr[rl].r,tr[rr].r,mid+1,r,L,R);
return res;
}
//以上是可持久化线段树操作
void dfs1( int rt,int fa){
depth[rt]=depth[fa]+1;f[rt]=fa;
sz[rt]=1;son[rt]=0;
for( int i=h[rt];i!=-1;i=ne[i]){
int j=to[i];
if(j==fa) continue;
else dfs1(j,rt);
sz[rt]+=sz[j];
if(sz[son[rt]]<sz[j])
son[rt]=j;
}
}
void dfs2( int rt,int fa){//先序遍历dfs
id[rt]=++cnt;
nw[id[rt]]=w[rt];
nw2[id[rt]]=w2[rt];
//这里w是点权,不是边权,nw数组是将树展开成链式结构后每个点的权值
if(son[fa]==rt) top[rt]=top[fa];
else top[rt]=rt;
if(!son[rt]) return ;
dfs2(son[rt],rt);//首先遍历重儿子
for( int i=h[rt];i!=-1;i=ne[i]){
int j=to[i];
if(j==fa||j==son[rt]) continue;
else dfs2(j,rt);
}
}
int query_color( int u,int v,int c){
int res=0;
while(top[u] !=top[v] ){//uv不在同一条链上
//保证v在u的上面
if(depth[top[u]]<depth[top[v]]) swap(u,v);
//更新树链
res+=q_color(head[id[top[u]]-1],head[id[u]],1,n,c,c);
u=f[top[u]];
}
if(depth[u]<depth[v]) swap(u,v);
res+=q_color(head[id[v]],head[id[u]],1,n,c,c);
return res;
}
int query_path_val( int u,int v,int c){
int res=0;
while(top[u] !=top[v] ){//uv不在同一条链上
//保证v在u的上面
if(depth[top[u]]<depth[top[v]]) swap(u,v);
//更新树链
res+=q_path(head[id[top[u]]-1],head[id[u]],1,n,c,c);
// cout<<id[top[u]]<<id[u]<<u<<top[u]<<endl;
u=f[top[u]];
}
if(depth[u]<depth[v]) swap(u,v);
res+=q_path(head[id[v]],head[id[u]],1,n,c,c);
return res;
}
int query_path_val_sum( int u,int v){
int res=0;
while(top[u] !=top[v] ){//uv不在同一条链上
//保证v在u的上面
if(depth[top[u]]<depth[top[v]]) swap(u,v);
res+=q_path(0,head[0],1,n,id[top[u]],id[u]);
u=f[top[u]];
}
if(depth[u]<depth[v]) swap(u,v);
res+=q_path(0,head[0],1,n,id[v]+1,id[u]);
return res;
}
int main(){
memset(h,-1,sizeof(h));
cin>>n>>q;
for( int i=1;i<n;i++){
int a,b,c,d;
scanf("%d%d%d%d",&a,&b,&c,&d);
add(a,b,c,d);
add(b,a,c,d);
}
dfs(1,0);
dfs1(1,0);
dfs2(1,0);
head[0]=create(1,n);
int cnt_head=1;
for( int i=2;i<=n;i++){
head[i]=insert(head[i-1],1,n,nw2[i],nw[i]);
}
for( int i=1;i<=q;i++){
int x,y,u,v;
scanf("%d%d%d%d",&x,&y,&u,&v);
int res=query_path_val_sum(u,v)-query_path_val(u,v,x)+query_color(u,v,x)*y;
printf("%d\n",res);
}
return 0;
}