题目链接
题意
给定一棵 $n$ 个节点的树，每个点上有权值。有 $m$ 次询问，每次给出 $s,t,a,b$，求 $s$ 到 $t$ 的路径上权值满足 $a\le val\le b$ 的点的权值和。
法一:dfs序+树状数组+离散化
思路
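将一条链拆成四条从某个结点到根节点的链，即转化为询问：根节点到某个结点的链上满足权值 $a\le val\le b$ 的点的权值和。记根到 $x$ 的这个值为 $f(x)$，则 $\mathrm{ans}(s,t)=f(s)+f(t)-f(\mathrm{lca})-f(\mathrm{fa}(\mathrm{lca}))$（当 $\mathrm{lca}$ 为根时最后一项为 $0$）。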
离线处理,将从链中拆出来的四个点 u,v,lca(u,v),fa(lca(u,v)) 以及询问的上下界等信息存到对应节点上。
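考虑 $\text{dfs}$：到 $u$ 时将 $val_u$ 按其离散化后的排名插入树状数组，此时树状数组中恰好是根到 $u$ 路径上所有点的权值，于是可以直接回答挂在 $u$ 上的所有询问；回溯离开 $u$ 时再将 $val_u$ 从树状数组中删除。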
Code
#include <bits/stdc++.h>
#define maxn 100010
using namespace std;
typedef long long LL;
struct node {
LL a, b;
int pos; bool type;
node(LL _a=0, LL _b=0, int _p=0, bool _t=0) : a(_a), b(_b), pos(_p), type(_t) {}
};
// Edge stored in a static pool: `to` is the neighbour and `ne` is the pool
// index of the next edge leaving the same vertex (-1 terminates the list).
struct Edge {
    int to;
    int ne;
    Edge(int target = 0, int link = 0) : to(target), ne(link) {}
} edge[maxn * 2];
// Heavy-light decomposition data, filled by dfs1/dfs2.
int dep[maxn], fa[maxn], son[maxn], sz[maxn], top[maxn];
// Adjacency-list heads: ne[u] is the first edge index of u, -1 if none.
int ne[maxn];
// q: number of distinct weights in g; tot: edges allocated; n, m: sizes.
int q, tot, n, m;
// val: vertex weights; g: sorted distinct weights (1-based);
// ret: per-query answers; c: Fenwick tree indexed by weight rank.
LL val[maxn], g[maxn], ret[maxn], c[maxn];
// que[u]: sub-queries evaluated when dfs reaches u; ans[u]: their results in the same order.
vector<node> que[maxn];
vector<LL> ans[maxn];
void addEdge(int u, int v) {
edge[tot] = Edge(v, ne[u]);
ne[u] = tot++;
}
void dfs1(int u, int f, int depth) {
dep[u] = depth, fa[u] = f, son[u] = -1, sz[u] = 1;
for (int i = ne[u]; ~i; i = edge[i].ne) {
int v = edge[i].to;
if (v == f) continue;
dfs1(v, u, depth+1);
sz[u] += sz[v];
if (son[u] == -1 || sz[son[u]] < sz[v]) son[u] = v;
}
}
void dfs2(int u, int sp) {
top[u] = sp;
if (son[u] == -1) return;
dfs2(son[u], sp);
for (int i = ne[u]; ~i; i = edge[i].ne) {
int v = edge[i].to;
if (v == fa[u] || v == son[u]) continue;
dfs2(v, v);
}
}
int LCA(int u, int v) {
int fau = top[u], fav = top[v];
while (fau != fav) {
if (dep[fau] < dep[fav]) swap(fau, fav), swap(u, v);
u = fa[fau], fau = top[u];
}
return dep[u] < dep[v] ? u : v;
}
// Value of the lowest set bit of x (the Fenwick-tree step size).
int lowbit(int x) {
    const int lowest = x & -x;
    return lowest;
}
// Fenwick point update: adds w at rank x (expects 1 <= x; stops past q).
void add(int x, LL w) {
    for (int i = x; i <= q; i += lowbit(i)) c[i] += w;
}
// Fenwick prefix sum over ranks 1..x (query(0) == 0).
LL query(int x) {
    LL sum = 0;
    for (int i = x; i != 0; i -= lowbit(i)) sum += c[i];
    return sum;
}
// Sum of the weights currently in the Fenwick tree that lie in [a, b].
// g[1..q] holds the sorted distinct weights, so:
//   #weights <= b is the rank bound to add, and
//   #weights <  a is the rank bound to subtract.
// lower_bound(a) is used instead of upper_bound(a - 1), so a - 1 cannot
// overflow. An empty range (a > b) returns 0; the plain difference would
// otherwise come out negative.
LL calc(LL a, LL b) {
    if (a > b) return 0;
    const int hi = upper_bound(g + 1, g + 1 + q, b) - (g + 1);  // #weights <= b
    const int lo = lower_bound(g + 1, g + 1 + q, a) - (g + 1);  // #weights <  a
    return query(hi) - query(lo);  // query(0) is 0, so no special-casing needed
}
void dfs(int u) {
vector<int> v;
add(upper_bound(g+1, g+1+q, val[u])-(g+1), val[u]);
if (!que[u].empty()) {
for (auto x : que[u]) {
ans[u].push_back(calc(x.a, x.b));
}
}
for (int i = ne[u]; ~i; i = edge[i].ne) {
int v = edge[i].to;
if (v == fa[u]) continue;
dfs(v);
}
add(upper_bound(g+1, g+1+q, val[u])-(g+1), -val[u]);
}
void work() {
for (int i = 1; i <= n; ++i) ans[i].clear(), que[i].clear();
tot = 0; memset(ne, -1, sizeof(ne));
memset(c, 0, sizeof c);
for (int i = 1; i <= n; ++i) {
scanf("%lld", &val[i]);
g[i] = val[i];
}
sort(g+1, g+1+n);
q = unique(g+1, g+1+n) - (g+1);
for (int i = 1; i < n; ++i) {
int u, v;
scanf("%d%d", &u, &v);
addEdge(u, v); addEdge(v, u);
}
dfs1(1, -1, 0);
dfs2(1, 1);
for (int i = 0; i < m; ++i) {
int s, t; LL a, b;
scanf("%d%d%lld%lld", &s, &t, &a, &b);
int lca = LCA(s, t);
que[s].push_back(node(a, b, i, 1)), que[t].push_back(node(a, b, i, 1)), que[lca].push_back(node(a, b, i, 0));
if (lca != 1) que[fa[lca]].push_back(node(a, b, i, 0));
}
dfs(1);
memset(ret, 0, sizeof ret);
for (int i = 1; i <= n; ++i) {
if (que[i].empty()) continue;
int sz = que[i].size();
for (int j = 0; j < sz; ++j) {
int p = que[i][j].pos;
if (que[i][j].type) ret[p] += ans[i][j];
else ret[p] -= ans[i][j];
}
}
printf("%lld", ret[0]);
for (int i = 1; i < m; ++i) printf(" %lld", ret[i]); printf("\n");
}
// Processes test cases until the input runs out. The loop requires both
// integers to be read (== 2) rather than testing != EOF. On malformed input
// scanf returns 0 without consuming anything, and the old check looped forever.
int main() {
    while (scanf("%d%d", &n, &m) == 2) work();
    return 0;
}
法二:树链剖分+线段树
思路
离线操作。
询问 $[a,b]$ 范围内的数的和，等价于 $[0,b]$ 范围的答案减去 $[0,a-1]$ 范围的答案。再转化一下：如果将树上的点按权值从小到大一个个插入线段树，并且在恰当的时机（已经插入的最大值 $\le$ 当前询问的值，且当前询问的值 $<$ 下一个要插入的值）询问，那么得到的即为权值范围在 $[0,\text{当前询问的值}]$ 内的路径权值和。
于是很显然地，就是将树上所有点的权值及询问中涉及到的权值一起排个序（权值相同时插入排在询问之前，这样询问值 $x$ 能包含权值恰为 $x$ 的点），依次进行两种操作：1. 插入点；2. 询问路径和。这就是裸的树链剖分了。
Code
#include <bits/stdc++.h>
#define maxn 100010
#define lson (rt << 1)
#define rson lson | 1
using namespace std;
typedef long long LL;
// Event for the offline sweep (the event list is sorted with cmp).
//   op == 0: insert vertex s's weight `val` into the segment tree.
//   op == 1: query the sum of the inserted weights on path s-t. The result
//            is added to ans[pos] if type is true, and subtracted otherwise.
struct node {
    LL val; int s, t;
    int pos; bool type; bool op;
    // Mem-initializers are listed in declaration order. Members are always
    // initialised in that order; the old list order triggered -Wreorder.
    node(bool _op=0, LL _val=0, int _s=0, int _t=0, int _pos=0, bool _type=0) :
        val(_val), s(_s), t(_t), pos(_pos), type(_type), op(_op) {}
};
// Segment-tree node: covers positions [l, r] and stores their sum in w.
struct Tree {
    int l;
    int r;
    LL w;
} tr[maxn * 4];
// Edge stored in a static pool: `to` is the neighbour and `ne` is the pool
// index of the next edge leaving the same vertex (-1 terminates the list).
struct Edge {
    int to;
    int ne;
    Edge(int target = 0, int link = 0) : to(target), ne(link) {}
} edge[maxn * 2];
// Heavy-light decomposition data; le[u] is u's position in the HLD order.
int dep[maxn], fa[maxn], son[maxn], sz[maxn], top[maxn], le[maxn];
// Adjacency-list heads: ne[u] is the first edge index of u, -1 if none.
int ne[maxn];
// cnt: last assigned HLD position; tot: edges allocated; n, m: sizes.
int cnt, tot, n, m;
// val: vertex weights; ans: per-query answers.
LL val[maxn], ans[maxn];
// Offline events (insertions and queries), sorted before the sweep.
vector<node> op;
void addEdge(int u, int v) {
edge[tot] = Edge(v, ne[u]);
ne[u] = tot++;
}
void dfs1(int u, int f, int depth) {
dep[u] = depth, fa[u] = f, son[u] = -1, sz[u] = 1;
for (int i = ne[u]; ~i; i = edge[i].ne) {
int v = edge[i].to;
if (v == f) continue;
dfs1(v, u, depth+1);
sz[u] += sz[v];
if (son[u] == -1 || sz[son[u]] < sz[v]) son[u] = v;
}
}
void dfs2(int u, int sp) {
top[u] = sp, le[u] = ++cnt;
if (son[u] == -1) return;
dfs2(son[u], sp);
for (int i = ne[u]; ~i; i = edge[i].ne) {
int v = edge[i].to;
if (v == fa[u] || v == son[u]) continue;
dfs2(v, v);
}
}
bool cmp(node u, node v) { return u.val < v.val || (u.val == v.val && u.op < v.op); }
void build(int rt, int l, int r) {
tr[rt].l = l, tr[rt].r = r, tr[rt].w = 0;
if (l == r) return;
int mid = l+r >> 1;
build(lson, l, mid); build(rson, mid+1, r);
}
// Point assignment: sets position p to x, then refreshes the sums on the
// path back up to rt.
void modify(int rt, LL x, int p) {
    if (tr[rt].l == tr[rt].r) {
        tr[rt].w = x;
        return;
    }
    const int mid = (tr[rt].l + tr[rt].r) >> 1;
    const int child = p <= mid ? lson : rson;
    modify(child, x, p);
    tr[rt].w = tr[lson].w + tr[rson].w;
}
// Sum over positions [l, r]. Expects [l, r] to lie within rt's range.
LL query(int rt, int l, int r) {
    const Tree& cur = tr[rt];
    if (cur.l == l && cur.r == r) return cur.w;
    const int mid = (cur.l + cur.r) >> 1;
    if (r <= mid) return query(lson, l, r);
    if (l > mid) return query(rson, l, r);
    const LL left = query(lson, l, mid);
    const LL right = query(rson, mid + 1, r);
    return left + right;
}
// Sum of the segment-tree values over every vertex on the path u-v.
// Walks chain by chain; each heavy chain is a contiguous range of le[].
LL ask(int u, int v) {
    LL total = 0;
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        total += query(1, le[top[u]], le[u]);
        u = fa[top[u]];
    }
    if (dep[u] < dep[v]) swap(u, v);
    total += query(1, le[v], le[u]);
    return total;
}
void work() {
tot = cnt = 0; memset(ne, -1, sizeof(ne));
op.clear();
for (int i = 1; i <= n; ++i) {
scanf("%lld", &val[i]);
op.push_back(node(0, val[i], i));
}
for (int i = 1; i < n; ++i) {
int u, v;
scanf("%d%d", &u, &v);
addEdge(u, v); addEdge(v, u);
}
dfs1(1, -1, 0);
dfs2(1, 1);
build(1, 1, n);
for (int i = 0; i < m; ++i) {
int s, t; LL a, b;
scanf("%d%d%lld%lld", &s, &t, &a, &b);
op.push_back(node(1, a-1, s, t, i, 0)); op.push_back(node(1, b, s, t, i, 1));
}
sort(op.begin(), op.end(), cmp);
memset(ans, 0, sizeof(ans));
for (auto nd : op) {
if (nd.op) {
LL ret = ask(nd.s, nd.t);
if (nd.type) ans[nd.pos] += ret;
else ans[nd.pos] -= ret;
}
else modify(1, nd.val, le[nd.s]);
}
printf("%lld", ans[0]);
for (int i = 1; i < m; ++i) printf(" %lld", ans[i]); printf("\n");
}
// Processes test cases until the input runs out. The loop requires both
// integers to be read (== 2) rather than testing != EOF. On malformed input
// scanf returns 0 without consuming anything, and the old check looped forever.
int main() {
    while (scanf("%d%d", &n, &m) == 2) work();
    return 0;
}
两种方法跑下来时间差不多0 0
以及w感觉很开心没怎么调就1A过了