题意:给一棵树,每个节点上有权值;m次询问,每次询问两点间最短距离经过的点中,权值>=a&&权值<=b的所有点的权值之和
思路:赛中只知道是树剖+线段树,但是不知道如何处理这个区间ab,赛后才知道可以在线段树的节点中存下最大值Max和最小值Min,假如我们query到某个线段树节点x,如果该节点表示的线段中的Max < a || Min > b,那么表示这条线段中没有点在ab中,所以直接返回;如果Max <= b && Min >= a,直接答案加上该线段权值之和。
虽然把这道题过了,但是这种写法应该会被卡;比如有人提出的1、2、1、2、1、2、1、2…、1、2,假如每次我们都查询最长长度,并且a = 1 , b = 1,那么每次都会走到最后一层,应该会T吧
代码随长,但并不难写(线段树模板+树剖模板),自己定义数组时将存边的数组开成了节点数,TLE了一天,长记性了
#include<bits/stdc++.h>
using namespace std;
#define inf 0x3f3f3f3f
#define IO ios::sync_with_stdio(false)
#define bug cout << "-----\n"
typedef long long ll;
int Mod = 1000000007;
const int N = 100010;
const int M = N * 2;//开空间!开空间!开空间!
int e[M],h[N],ne[M],idx;
int siz[N],f[N],d[N],son[N],top[N],id[N],cnt;
ll ans,a[N],w[N];
struct Tree {
ll l,r,Max,Min,num;
}tree[N << 2];
void build(int x,int l,int r) {
tree[x].l = l;tree[x].r = r;
if(l == r) {
tree[x].num = w[l];
tree[x].Max = w[l];
tree[x].Min = w[l];
return ;
}
int mid = l + r >> 1;
build(x << 1 , l , mid);
build(x << 1 | 1 , mid + 1 , r);
tree[x].num = tree[x << 1].num + tree[x << 1 | 1].num;
tree[x].Max = max(tree[x << 1].Max , tree[x << 1 | 1].Max);
tree[x].Min = min(tree[x << 1].Min , tree[x << 1 | 1].Min);
}
void find(int x,int l,int r,int st,int ed) {
if(tree[x].Max < st || tree[x].Min > ed)return ;
if(tree[x].l >= l && tree[x].r <= r && tree[x].Max <= ed && tree[x].Min >= st) {
ans += tree[x].num;
return ;
}
if(r <= tree[x << 1].r)find(x << 1 , l , r , st , ed);
else if(l >= tree[x << 1 | 1].l)find(x << 1 | 1 , l , r , st , ed);
else {
find(x << 1 , l , tree[x << 1].r , st , ed);
find(x << 1 | 1 , tree[x << 1 | 1].l , r , st , ed);
}
}
void add(int x,int y) {
e[idx] = y;ne[idx] = h[x];h[x] = idx ++;
}
void dfs1(int x,int fa,int deap) {
siz[x] = 1;
f[x] = fa;
d[x] = deap;
int bigson = -1;
for(int i = h[x] ; ~i ; i = ne[i]) {
int t = e[i];
if(t == fa)continue;
dfs1(t , x , deap + 1);
siz[x] += siz[t];
if(siz[t] > bigson)bigson = siz[t],son[x] = t;
}
}
void dfs2(int x,int tp) {
top[x] = tp;
id[x] = ++ cnt;
w[cnt] = a[x];
if(!son[x])return ;
dfs2(son[x] , tp);
for(int i = h[x] ; ~i ; i = ne[i]) {
int t = e[i];
if(t == son[x] || t == f[x])continue;
dfs2(t , t);
}
}
ll query(int x,int y,int st,int ed) {
ans = 0;
while(top[x] != top[y]) {
if(d[top[x]] < d[top[y]])swap(x , y);
find(1 , id[top[x]] , id[x] , st , ed);
x = f[top[x]];
}
if(d[x] > d[y])swap(x , y);
find(1 , id[x] , id[y] , st , ed);
return ans;
}
int main() {
IO;
int n , m;
while(cin >> n >> m) {
memset(h , -1 , sizeof h);
memset(son , 0 , sizeof son);
idx = 0;cnt = 0;
int x , y , st , ed;
for(int i = 1 ; i <= n ; i ++)
cin >> a[i];
for(int i = 1 ; i < n ; i ++) {
cin >> x >> y;
add(x , y);add(y , x);
}
dfs1(1 , 0 , 1);
dfs2(1 , 1);
build(1 , 1 , n);
while(m --) {
cin >> x >> y >> st >> ed;
cout << query(x , y , st , ed);
if(m)cout << ' ';
}
cout << '\n';
}
return 0;
}