有一个 N 个节点的树,每条边有颜色、边权。
您需要处理 𝑄Q 个询问,每个询问给出 𝑥𝑖,𝑦𝑖,𝑢𝑖,𝑣𝑖,您需要求出假定所有颜色为 𝑥𝑖 的边边权全部变成 𝑦𝑖 后,𝑢𝑖 和 𝑣𝑖 之间的距离。询问之间互相独立。
思路 :发现树上路径就是 LCA 到两个端点的距离 颜色不同 因为颜色数量过多 没法状压 考虑对不同颜色建立主席树 每个节点是不同的颜色 查询颜色k 返回的答案即为 tr[u].tot * c - tr[u].sum
tr[u].tot 为当前根的颜色k的节点数量 c为被替换的值 sum为之前的和 LCA可以考虑倍增去处理
#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long ULL;
using LL = long long;
constexpr int N = 3e5 + 5, mod = 998244353;
constexpr double eps = 1e-8;
// #pragma GCC optimize("Ofast,no-stack-protector,unroll-loops,fast-math")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,avx,avx2,popcnt,tune=native")
#define fi first
#define se second
#define int long long
#define lowbit(x) (x & (-x))
#define PII pair<int, int>
#define mid ((l + r) >> 1)
int min(int a, int b) { return a < b ? a : b; }
int max(int a, int b) { return a > b ? a : b; }
int ksm(int a, int b){
int res = 1;
while(b){
if(b & 1)res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
struct Node{
int lc, rc, tot, sum;
}tr[N * 40];
int n, q;
vector<tuple<int, int, int>>g[N];
int dep[N], fa[N][21], root[N], dist[N], cnt;
void update(int &u, int v, int l, int r, int c, int d){
u = ++ cnt;
tr[u] = tr[v];
tr[u].tot ++;
tr[u].sum += d;
if(l == r)return;
if(c <= mid) update(tr[u].lc, tr[v].lc, l, mid, c, d);
else update(tr[u].rc, tr[v].rc, mid + 1, r, c, d);
}
int query(int u, int l, int r, int c, int w){
if(l == r)return tr[u].tot * w - tr[u].sum;
if(c <= mid)return query(tr[u].lc, l, mid, c, w);
else return query(tr[u].rc, mid + 1, r, c, w);
}
void dfs(int u, int f){
dep[u] = dep[f] + 1;
fa[u][0] = f;
for(int i = 1; i <= 20; ++ i)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
for(auto [v, c, w] : g[u]){
if(v == f)continue;
dist[v] = dist[u] + w;
update(root[v], root[u], 1, n, c, w);
dfs(v, u);
}
}
int LCA(int x, int y){
if(dep[x] < dep[y])swap(x, y);
for(int i = 20; ~i; -- i){
if(dep[x] - (1LL << i) >= dep[y]){
x = fa[x][i];
}
}
if(x == y)return x;
for(int i = 20; ~i; -- i){
if(fa[x][i] != fa[y][i]){
x = fa[x][i];
y = fa[y][i];
}
}
return fa[x][0];
}
void Sakuya()
{
cin >> n >> q;
for(int i = 1; i <= n - 1; ++ i){
int u, v, c, w;
cin >> u >> v >> c >> w;
g[u].emplace_back(v, c, w);
g[v].emplace_back(u, c, w);
}
dfs(1, 1);
while(q --){
int c, w, u, v;
cin >> c >> w >> u >> v;
int lca = LCA(u, v);
cout << dist[u] + dist[v] - 2 * dist[lca] + query(root[u], 1, n, c, w) + query(root[v], 1, n, c, w) - 2 * query(root[lca], 1, n, c, w) << "\n";
}
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
// int T;
// for (cin >> T; T -- ; )
Sakuya();
}