传送门
题目大意: 一棵树, 树上有点权, 每次询问 u 到 v 的树上路径上, 点权在 [a, b] 范围内的点的权值和是多少.
思路: 这道题居然暴力能过, 而且跑得飞快... 后面还是写了写正解. 很明显我们需要把 u -> v 的路径用树链剖分拆成若干段连续区间, 然后问题就变成区间问题了: 询问区间中值在一定范围内的数的和. 我们知道主席树可以求区间中小于等于某个数的所有数之和, 所以可以用主席树来维护: 分别求出 <= a-1 和 <= b 的和, 两者相减就是答案.
AC Code: 树剖 + 主席树
// Capacity of the vertex-indexed arrays (vertices are numbered 1..n, so
// a little slack is left past 1e5).
const int maxn = 100000 + 5;
int n;  // number of vertices in the current test case
int m;  // number of queries in the current test case
// Node pool of the persistent segment tree built over the compressed value
// domain. Slot 0 is never handed out, because allocation pre-increments idx.
struct Tree {
    int ls;  // index in tre[] of the left child
    int rs;  // index in tre[] of the right child
    ll val;  // sum of the weights inserted into this node's value range
} tre[maxn * 40];

int idx;         // last slot of tre[] handed out; reset to 0 by init()
int root[maxn];  // root[i]: version after inserting the first i vertices in DFS order
// Allocate an all-zero tree covering the value range [l, r] and return its
// root. Nodes are numbered in pre-order: parent, then left, then right subtree.
int build(int l, int r) {
    const int cur = ++idx;
    tre[cur].val = 0;
    if (l != r) {
        const int mid = (l + r) >> 1;
        tre[cur].ls = build(l, mid);
        tre[cur].rs = build(mid + 1, r);
    }
    return cur;
}
// Create a new version from `pre` by adding v at compressed position pos.
// One fresh node per level along the root-to-leaf path of pos is cloned from
// the matching node of `pre`, and v is added to its sum. The off-path
// children stay shared with `pre`. Returns the root of the new version.
int update(int pre, int l, int r, int pos, int v) {
    const int fresh_root = ++idx;
    int cur = fresh_root;  // node being filled in the new version
    int old = pre;         // matching node in the previous version
    for (;;) {
        tre[cur] = tre[old];
        tre[cur].val += v;
        if (l == r) break;
        const int mid = (l + r) >> 1;
        const int child = ++idx;
        if (pos <= mid) {
            tre[cur].ls = child;
            old = tre[old].ls;
            r = mid;
        } else {
            tre[cur].rs = child;
            old = tre[old].rs;
            l = mid + 1;
        }
        cur = child;
    }
    return fresh_root;
}
// ql and qr are the roots of an older and a newer version. Returns
// (sum in qr) - (sum in ql) of the weights at compressed positions [l, pos].
// That is the prefix sum over the value domain, restricted to the insertions
// made between the two versions.
// Callers must pass pos >= 1: with pos == 0 the walk still ends on leaf l
// and counts it.
ll query_sum(int ql, int qr, int l, int r, int pos) {
    ll acc = 0;
    while (l != r) {
        const int mid = (l + r) >> 1;
        if (pos > mid) {
            // The whole left half lies at or below pos: take it, go right.
            acc += tre[tre[qr].ls].val - tre[tre[ql].ls].val;
            ql = tre[ql].rs;
            qr = tre[qr].rs;
            l = mid + 1;
        } else {
            ql = tre[ql].ls;
            qr = tre[qr].ls;
            r = mid;
        }
    }
    return acc + (tre[qr].val - tre[ql].val);
}
int a[maxn];     // a[v]: weight of vertex v
int len;         // size of the compressed value domain (distinct weights)
vector<int> ve;  // weights, sorted and de-duplicated in solve() for compression
// 1-based rank of x in the compressed domain. ve must already be sorted and
// de-duplicated; x is expected to be one of its values.
int getid(int x) {
    const auto it = lower_bound(ve.begin(), ve.end(), x);
    return static_cast<int>(it - ve.begin()) + 1;
}
int cnt;         // number of slots of e[] in use
int head[maxn];  // head[u]: first edge index of u's adjacency list, -1 if none
int tim;         // running DFS timestamp used by dfs2
int siz[maxn];   // siz[u]: number of vertices in u's subtree
int top[maxn];   // top[u]: first vertex of the heavy chain containing u
int tid[maxn];   // tid[u]: DFS position of u (heavy child visited first)
int pos[maxn];   // pos[t]: vertex at DFS position t (inverse of tid)
int son[maxn];   // son[u]: heavy child of u, -1 for a leaf
int dep[maxn];   // dep[u]: depth of u, the root has depth 1
int fa[maxn];    // fa[u]: parent of u
// Adjacency-list edge; every undirected tree edge occupies two slots of e[].
struct node {
    int to;    // endpoint of the edge
    int next;  // next edge index in the same list, -1 terminates it
    int w;     // edge weight (stored but not read by the visible code)
} e[maxn << 1];
void add(int u, int v, int w) {
e[cnt] = node{v, head[u], w};
head[u] = cnt++;
}
// Reset the per-test-case graph, DFS and tree-pool state.
// Fill is a macro defined outside this excerpt; presumably it sets every
// element of the array to the given value (memset-style) -- confirm.
void init() {
    idx = 0;
    tim = 0;
    cnt = 0;
    Fill(head, -1);
    Fill(son, -1);
}
// First heavy-light pass over the subtree of u, where f is u's parent and
// deep is the parent's depth. Fills dep, siz, fa of the children, and son
// (the child with the largest subtree; the first one wins ties).
void dfs1(int u, int f, int deep) {
    siz[u] = 1;
    dep[u] = deep + 1;
    for (int i = head[u]; i != -1; i = e[i].next) {
        const int v = e[i].to;
        if (v != f) {
            fa[v] = u;
            dfs1(v, u, deep + 1);
            siz[u] += siz[v];
            const bool heavier = (son[u] == -1) || (siz[v] > siz[son[u]]);
            if (heavier) son[u] = v;
        }
    }
}
// Second heavy-light pass. It assigns DFS positions so that every heavy chain
// occupies a contiguous range of tid, and records each chain's top vertex tp.
void dfs2(int u, int tp) {
    top[u] = tp;
    ++tim;
    tid[u] = tim;
    pos[tim] = u;
    if (son[u] != -1) {
        // The heavy child continues the current chain and is numbered first.
        dfs2(son[u], tp);
        for (int i = head[u]; i != -1; i = e[i].next) {
            const int v = e[i].to;
            if (v == son[u] || v == fa[u]) continue;
            dfs2(v, v);  // every light child starts a new chain
        }
    }
}
// Sum of the weights of the vertices on the tree path x..y (both ends
// included) whose weight lies in [a, b].
//
// Each heavy-light segment is a contiguous DFS range [tid[lo], tid[hi]].
// Its contribution is a prefix-sum difference over the compressed value
// domain, taken between the versions root[tid[lo] - 1] and root[tid[hi]].
ll get_sum(int x, int y, int a, int b) {
    // An empty range has sum 0. Without this check the prefix difference
    // below turns negative. The brute-force cal() also yields 0 when L > R.
    if (a > b) return 0;
    // p1 = number of distinct weights < a, p2 = number of distinct weights <= b.
    // lower_bound(a) is the same count as upper_bound(a - 1) for ints, but it
    // avoids overflowing a - 1 when a == INT_MIN.
    const int p1 = lower_bound(ve.begin(), ve.end(), a) - ve.begin();
    const int p2 = upper_bound(ve.begin(), ve.end(), b) - ve.begin();
    if (p1 == p2) return 0;  // no stored weight falls inside [a, b]

    ll total = 0;
    // Add the in-range weight sum of the DFS positions [l, r], with l <= r.
    // p2 >= 1 here; p1 may be 0, and query_sum must not be called with pos 0.
    auto add_range = [&](int l, int r) {
        const int older = root[l - 1];
        const int newer = root[r];
        total += query_sum(older, newer, 1, len, p2);
        if (p1) total -= query_sum(older, newer, 1, len, p1);
    };
    // Climb chain by chain, always lifting the endpoint whose chain top is deeper.
    for (; top[x] != top[y]; x = fa[top[x]]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        add_range(tid[top[x]], tid[x]);
    }
    // Both endpoints are now on one chain; x is made the shallower one.
    if (dep[x] > dep[y]) swap(x, y);
    add_range(tid[x], tid[y]);
    return total;
}
// Process test cases until EOF. Each case has n vertex weights, n-1
// undirected edges, then m queries "s t a b". The m answers of one case are
// printed on a single space-separated line.
void solve() {
    while (~scanf("%d%d", &n, &m)) {
        // ve is global. Without clearing it, weights from earlier test cases
        // pile up: len grows, build(1, len) allocates ever more nodes, and
        // tre[] (maxn * 40 slots) eventually overflows.
        ve.clear();
        for (int i = 1; i <= n; i++) {
            scanf("%d", a + i);
            ve.pb(a[i]);
        }
        init();
        for (int i = 1; i < n; i++) {
            int u, v;
            scanf("%d%d", &u, &v);
            add(u, v, 1);
            add(v, u, 1);
        }
        // NOTE(review): dfs1/dfs2 recurse once per tree level; a path-shaped
        // tree with ~1e5 vertices may need a large stack -- confirm judge limits.
        dfs1(1, -1, 0);
        dfs2(1, 1);
        // Coordinate-compress the weights.
        sort(ve.begin(), ve.end());
        ve.erase(unique(ve.begin(), ve.end()), ve.end());
        len = sz(ve);
        // Version i holds the weights of the first i vertices in DFS order.
        root[0] = build(1, len);
        for (int i = 1; i <= n; i++) {
            root[i] = update(root[i - 1], 1, len, getid(a[pos[i]]), a[pos[i]]);
        }
        for (int i = 1; i <= m; i++) {
            int s, t, a, b;
            scanf("%d%d%d%d", &s, &t, &a, &b);
            printf("%lld%c", get_sum(s, t, a, b), i == m ? '\n' : ' ');
        }
    }
}
LCA暴力代码 (居然比正解还快了 1s...): 而且代码非常好写, 所以这类题目如果不想写得那么麻烦, 可以试试暴力.
// Capacity of the vertex-indexed arrays (vertices are numbered 1..n).
const int maxn = 100000 + 5;
int n;           // number of vertices
int q;           // number of queries
int fa[maxn];    // fa[u]: parent of u, filled by dfs()
int a[maxn];     // a[u]: weight of vertex u
int head[maxn];  // head[u]: first edge index of u's adjacency list, -1 if none
int cnt;         // number of slots of e[] in use
int deep[maxn];  // deep[u]: depth of u, the root has depth 1
int L;           // lower bound of the current query's weight range
int R;           // upper bound of the current query's weight range

// Adjacency-list edge; every undirected tree edge occupies two slots of e[].
struct node {
    int to;    // endpoint of the edge
    int next;  // next edge index in the same list, -1 terminates it
    int w;     // edge weight (stored but not read by the visible code)
} e[maxn << 1];
// Empty every adjacency list before reading a new test case.
// Fill is a macro defined outside this excerpt; presumably a memset-style
// fill of the whole array -- confirm.
void init() {
    cnt = 0;
    Fill(head, -1);
}
void add(int u, int v, int w) {
e[cnt] = node{v, head[u], w};
head[u] = cnt++;
}
// Root the tree: record depth (parent depth d plus one) and the parent of
// every vertex below u. f is u's parent and is skipped while iterating.
void dfs(int u, int f, int d) {
    deep[u] = d + 1;
    for (int i = head[u]; i != -1; i = e[i].next) {
        const int v = e[i].to;
        if (v == f) continue;
        fa[v] = u;
        dfs(v, u, d + 1);
    }
}
// Weight filter for the current query: x when L <= x <= R, otherwise 0.
int cal(int x) {
    const bool inside = !(x < L || x > R);
    return inside ? x : 0;
}
// Brute-force path query. Despite the name, it returns a sum, not the LCA.
// It walks u and v up to their lowest common ancestor one parent at a time,
// adding cal(weight) for every vertex on the path; the LCA is counted once.
// Cost is proportional to the path length.
ll Find_lca(int u, int v) {
    ll ans = 0;
    if (deep[u] < deep[v]) swap(u, v);
    // Lift the deeper endpoint until both are on the same level.
    while (deep[u] != deep[v]) {
        ans += cal(a[u]);
        u = fa[u];
    }
    // Climb in lockstep until the two walkers meet at the LCA.
    // When u == v already, the loop body never runs.
    while (u != v) {
        // Accumulate one term at a time. cal(a[u]) + cal(a[v]) would be
        // computed in int and could overflow before widening to ll.
        ans += cal(a[u]);
        ans += cal(a[v]);
        u = fa[u];
        v = fa[v];
    }
    return ans + cal(a[u]);  // the LCA itself
}
// Process test cases until EOF. Each case has n vertex weights, n-1
// undirected edges, then q queries "s t L R", answered by walking the path.
// The answers of one case are printed on a single space-separated line.
void solve() {
    while (scanf("%d%d", &n, &q) != -1) {
        init();
        for (int v = 1; v <= n; ++v) scanf("%d", &a[v]);
        for (int k = 1; k < n; ++k) {
            int u, v;
            scanf("%d%d", &u, &v);
            add(u, v, 1);
            add(v, u, 1);
        }
        dfs(1, -1, 0);
        for (int k = 1; k <= q; ++k) {
            int s, t;
            scanf("%d%d%d%d", &s, &t, &L, &R);
            const char sep = (k == q) ? '\n' : ' ';
            printf("%lld%c", Find_lca(s, t), sep);
        }
    }
}