题目
解题思路
使用树链剖分，并将询问离线处理。
将询问按阈值从大到小排序，同时按边权从大到小依次加入边（边权不小于当前阈值的边记为 1，其余为 0）；把每条边的边权下放到深度较大的端点上（边权转点权），用线段树维护每个区间的左端连续 1 段长度、右端连续 1 段长度以及区间答案值。
复杂度
$O(n\log_2^2 n)$（设 $q$ 与 $n$ 同阶：每次询问经过 $O(\log n)$ 条重链，每条重链上做一次 $O(\log n)$ 的线段树查询）。
代码
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
typedef long long ll;
// Reads one decimal integer from stdin into x, skipping any non-digit
// characters before it. A '-' anywhere in that skipped prefix makes the
// result negative. If EOF is reached before any digit is seen, x is set
// to 0 and the function returns instead of looping forever.
void read(int &x) {
    x = 0;
    int sign = 1;
    // Keep the getchar() result as int: storing it in a char makes EOF
    // indistinguishable from a real character (and never equal to EOF on
    // unsigned-char platforms), which turned the skip loop infinite.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') sign = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    x *= sign;
}
// Writes x to stdout in decimal, prefixed with '-' when negative, with no
// trailing newline. Valid for the whole long long range, including
// LLONG_MIN, whose signed negation would overflow (undefined behavior).
void write(long long x) {
    // Take the magnitude in unsigned arithmetic, where wrap-around is defined,
    // so that |LLONG_MIN| is representable.
    unsigned long long mag = static_cast<unsigned long long>(x);
    if (x < 0) {
        putchar('-');
        mag = 0ULL - mag;
    }
    char digits[24];  // 2^64 has 20 decimal digits
    int count = 0;
    do {
        digits[count++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag != 0);
    while (count > 0) putchar(digits[--count]);
}
// Child indices of segment-tree node `rt` (1-based heap layout). Both
// expansions are parenthesized so they stay correct inside larger
// expressions (unparenthesized, `rson * 2` meant `rt * 2 + 1 * 2`).
#define lson (rt * 2)
#define rson (rt * 2 + 1)
// Array capacity: vertex numbers and query indices must stay below N.
const int N = 100000 + 100;
int n, q;  // number of vertices, number of queries
// f[k] is the value contributed by one maximal run of k consecutive
// "on" edges. Only f[1..n-1] are read from input; f[0] stays 0, which
// the segment-tree merge relies on when a boundary run is empty.
int f[N];
// Summary of a contiguous range of edge positions, each edge being "on"
// (already inserted) or "off":
//   l   - length of the maximal run of "on" edges at the left end
//   r   - length of the maximal run of "on" edges at the right end
//   len - number of positions covered
//   val - sum of f[k] over every maximal "on" run of length k, the two
//         boundary runs (l and r) included
// A default-constructed node looks like a single "off" edge.
// NOTE(review): internal nodes of `tree` that were never updated keep this
// default (len == 1), so their len can be smaller than their real span. That
// appears harmless because len is only compared against l / r to detect an
// all-"on" range, and such a range never holds a default node.
struct node {
    int l;
    int r;
    int len;
    ll val;
    node() : l(0), r(0), len(1), val(0) {}
    node(int l, int r, int len, ll val) : l(l), r(r), len(len), val(val) {}
} tree[N * 4];
// Concatenates two range summaries: `a` is the left part, `b` the right part.
// The order matters, because only a's right run and b's left run touch.
node cal(node a, node b) {
    const int span = a.len + b.len;
    // A side that is entirely "on" lets the boundary run extend across it.
    const int leftRun = (a.l == a.len) ? a.l + b.l : a.l;
    const int rightRun = (b.r == b.len) ? b.r + a.r : b.r;
    // a.r and b.l become one run: drop their separate contributions and add
    // the merged one. An empty run indexes f[0], which must be 0.
    const ll merged = a.val + b.val - f[a.r] - f[b.l] + f[a.r + b.l];
    return node(leftRun, rightRun, span, merged);
}
// Turns position `id` "on": its leaf becomes a single-edge run worth f[1],
// and every ancestor on the way back up is recomputed from its children.
// `rt` is the node currently covering [l, r].
void insert(int id, int l, int r, int rt) {
    if (l == r) {
        tree[rt] = node(1, 1, 1, f[1]);
        return;
    }
    const int mid = (l + r) / 2;
    if (id > mid) {
        insert(id, mid + 1, r, rson);
    } else {
        insert(id, l, mid, lson);
    }
    tree[rt] = cal(tree[lson], tree[rson]);
}
// Returns the summary of positions [rl, rr], searching below node `rt`,
// which covers [l, r]. Callers pass l <= rl <= rr <= r.
node query(int rl, int rr, int l, int r, int rt) {
    if (l == rl && r == rr) return tree[rt];
    const int mid = (l + r) / 2;
    if (rr <= mid) return query(rl, rr, l, mid, lson);
    if (rl > mid) return query(rl, rr, mid + 1, r, rson);
    // The range straddles mid: combine the left piece with the right piece.
    const node leftPart = query(rl, mid, l, mid, lson);
    const node rightPart = query(mid + 1, rr, mid + 1, r, rson);
    return cal(leftPart, rightPart);
}
// Heavy-light decomposition state, indexed by vertex (rev by position).
int tot;           // number of positions assigned so far by dfs2
int siz[N];        // subtree size
int faz[N];        // parent (0 for the root)
int dep[N];        // depth (root has depth 1)
int son[N];        // heavy child (0 when the vertex has no children)
int tid[N];        // position of the vertex in the segment tree
int top[N];        // head vertex of the vertex's heavy chain
int rev[N];        // inverse of tid: the vertex stored at a position
vector<int> V[N];  // adjacency lists of the tree
void dfs1(int u, int fa) {
faz[u] = fa;
dep[u] = dep[fa] + 1;
siz[u] = 1;
for (int v : V[u]) {
if (v == fa) continue;
dfs1(v, u);
siz[u] += siz[v];
if (siz[son[u]] < siz[v]) son[u] = v;
}
}
// Second HLD pass: hands out consecutive positions in heavy-child-first DFS
// order, so each heavy chain occupies a contiguous range of positions, and
// records t as the chain head of u.
void dfs2(int u, int t) {
    top[u] = t;
    tid[u] = ++tot;
    rev[tot] = u;
    const int heavy = son[u];
    if (heavy != 0) dfs2(heavy, t);  // the heavy child continues u's chain
    for (const int w : V[u]) {
        const bool lightChild = (w != faz[u]) && (w != heavy);
        if (lightChild) dfs2(w, w);  // every light child starts a new chain
    }
}
// Answers one query on the current set of "on" edges: returns the sum of
// f[k] over every maximal run of k consecutive "on" edges along the tree
// path x -> y. Each vertex's position stands for the edge to its parent
// (main inserts an edge at its deeper endpoint).
// rx accumulates the x side of the path and ry the y side. Both are kept in
// segment-tree order, i.e. the end nearest the meeting chain is on the left,
// and every new piece is prepended with cal(piece, acc).
// Both start as the default node, which acts like one extra "off" edge on
// the far (x / y) end of the path. That end never takes part in a merge,
// and f[0] is 0, so the sentinel does not change val.
ll solve(int x, int y) {
int fx = top[x], fy = top[y]; node rx, ry;
// Repeatedly lift whichever endpoint has the deeper chain head, until both
// endpoints are on the same heavy chain.
while (fx != fy) {
if (dep[fx] > dep[fy]) {
rx = cal(query(tid[fx], tid[x], 1, n, 1), rx);
x = faz[fx]; fx = top[x];
}
else {
ry = cal(query(tid[fy], tid[y], 1, n, 1), ry);
y = faz[fy]; fy = top[y];
}
}
// Same chain now. The shallower of x and y is where the two sides meet; its
// own position (the edge above it) is not on the path, hence the "+ 1".
if (x != y) {
if (tid[x] < tid[y]) ry = cal(query(tid[x] + 1, tid[y], 1, n, 1), ry);
else rx = cal(query(tid[y] + 1, tid[x], 1, n, 1), rx);
}
// Reverse rx (only its end runs matter) so that its meeting-side run is on
// the right, touching ry's left run. Then merge the two sides.
swap(rx.l, rx.r);
return cal(rx, ry).val;
}
// A weighted vertex pair. It is used both for tree edges (array e, where id
// is left unused) and for queries (array qu, where id is the 1-based input
// index of the query).
struct Qu {
    int u;
    int v;
    int w;
    int id;
} e[N], qu[N];
// "Less than" means "heavier", so sort() arranges items by descending w.
bool operator < (Qu a, Qu b) {
    return b.w < a.w;
}
ll ans[N];
int main() {
//freopen("0.txt", "r", stdin);
read(n); read(q);
for (int i = 1; i < n; i++) read(f[i]);
for (int i = 1; i < n; i++) {
read(e[i].u); read(e[i].v); read(e[i].w);
V[e[i].u].push_back(e[i].v);
V[e[i].v].push_back(e[i].u);
}
for (int i = 1; i <= q; i++) read(qu[i].u), read(qu[i].v), read(qu[i].w), qu[i].id = i;
sort(e + 1, e + n);
sort(qu + 1, qu + q + 1);
dfs1(1, 0);
dfs2(1, 1);
for (int i = 1, j = 1; i <= q; i++) {
while (j < n && e[j].w >= qu[i].w) {
int id = e[j].u;
if (dep[e[j].u] < dep[e[j].v]) id = e[j].v;
insert(tid[id], 1, n, 1);
j++;
}
ans[qu[i].id] = solve(qu[i].u, qu[i].v);
}
for (int i = 1; i <= q; i++) write(ans[i]), puts("");
return 0;
}