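// Problem: given a tree with n nodes and a weight on every edge, support two operations:
//   0 x      print the shortest distance from the current node to node x, then move to node x
//   1 x val  change the weight of the x-th edge to val
// Approach: a basic heavy-light decomposition (tree chain decomposition) problem. The catch is
// that the time limit is too tight for vector: storing the edges in a vector gets TLE.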
#include<cstdio>
#include<cstring>
#include<cmath>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#include<vector>
#include<map>
#include<queue>
#include<stack>
#include<string>
#include<map>
#include<set>
#include<ctime>
#define eps 1e-6
#define LL long long
#define pii pair<int, int>
#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;
// Capacity used to size every per-node array in this file.
const int MAXN = 101000;

// Segment-tree storage: sumv[o] is the sum kept at tree node o
// (root is node 1, children of o are 2*o and 2*o+1).
int sumv[MAXN*4];

// One directed arc of the linked adjacency list.
struct Edge {
    int to;    // node this arc points at
    int next;  // index of the next arc leaving the same source, -1 at the end
};

// Arc pool; each undirected tree edge is stored as two arcs.
Edge edge[2*MAXN];

// head[v] is the index in `edge` of the most recently added arc leaving v, or -1.
int head[MAXN];

// Number of arcs stored so far, i.e. the next free slot in `edge`.
int tot;
void Add(int from, int to) {
edge[tot].to = to;
edge[tot].next = head[from];
head[from] = tot ++;
}
// Range-sum query on the segment tree stored in sumv.
//   o        index of the current tree node (root is 1)
//   [L, R]   interval of positions covered by node o
//   [ql, qr] interval being queried
// Returns the sum of the values at positions in [ql, qr] that lie inside [L, R];
// returns 0 when the query interval is empty (ql > qr) or does not intersect [L, R].
int query(int o, int L, int R, int ql, int qr) {
    // Without this guard an empty or disjoint interval never satisfies the
    // "fully covered" test at a leaf and the recursion would not terminate.
    if (ql > qr || qr < L || R < ql) return 0;
    // Node interval entirely inside the query: its stored sum is the answer.
    if (ql <= L && R <= qr) return sumv[o];
    const int M = (L + R) >> 1;
    int ans = 0;
    if (ql <= M) ans += query(o*2, L, M, ql, qr);        // left half overlaps
    if (M < qr) ans += query(o*2+1, M+1, R, ql, qr);     // right half overlaps
    return ans;
}
// Point assignment on the segment tree stored in sumv.
//   o       index of the current tree node (root is 1)
//   [L, R]  interval of positions covered by node o
//   p       position whose value is overwritten
//   v       new value for position p
// Sums of all nodes on the root-to-leaf path are recomputed on the way back up.
void update(int o, int L, int R, int p, int v) {
    if (L == R) {
        // Reached the leaf: store the new value directly.
        sumv[o] = v;
        return;
    }
    const int mid = (L + R) >> 1;
    const int left = o * 2;
    const int right = o * 2 + 1;
    if (p <= mid) {
        update(left, L, mid, p, v);
    } else {
        update(right, mid + 1, R, p, v);
    }
    sumv[o] = sumv[left] + sumv[right];
}
// Number of tree nodes; also the right end of the segment-tree range [1, n].
int n;
// Last segment-tree position handed out by dfs2 (positions start at 1).
int TOT;
// Per-node data filled in by dfs / dfs2, indexed by node id.
int siz[MAXN];  // size of the subtree rooted at the node
int son[MAXN];  // child with the largest subtree, 0 if the node has no child
int dep[MAXN];  // depth: a child's depth is its parent's depth + 1
int top[MAXN];  // topmost node of the chain the node belongs to
int fa[MAXN];   // parent of the node
int pos[MAXN];  // position of the node in the segment tree
// Resets the state that must be fresh for every test case: the adjacency
// list, the segment tree, the heavy-child table and both counters.
void init() {
    // No arcs stored and no segment-tree positions assigned yet.
    TOT = 0;
    tot = 0;
    // -1 marks an empty adjacency list.
    std::fill(head, head + sizeof(head) / sizeof(head[0]), -1);
    // 0 means "no heavy child" and "no weight stored".
    std::fill(son, son + sizeof(son) / sizeof(son[0]), 0);
    std::fill(sumv, sumv + sizeof(sumv) / sizeof(sumv[0]), 0);
}
// First pass of the decomposition: for the subtree rooted at `cur` (reached
// from parent `f`) fills in fa, dep and siz, and picks son[cur] as the child
// with the largest subtree (the first such child wins ties).
// Recursive: the call depth grows with the height of the tree.
void dfs(int cur, int f) {
    int heaviest = 0;  // largest child subtree size seen so far
    siz[cur] = 1;
    for (int arc = head[cur]; arc != -1; arc = edge[arc].next) {
        const int child = edge[arc].to;
        if (child != f) {  // skip the arc leading back to the parent
            dep[child] = dep[cur] + 1;
            fa[child] = cur;
            dfs(child, cur);
            siz[cur] += siz[child];
            if (siz[child] > heaviest) {
                son[cur] = child;
                heaviest = siz[child];
            }
        }
    }
}
// Second pass of the decomposition: assigns segment-tree positions and chain
// heads. `tp` is the top node of the chain that `cur` belongs to.
// The heavy child is visited first so every chain occupies consecutive
// positions; each other child starts a new chain headed by itself.
void dfs2(int cur, int tp) {
    pos[cur] = ++TOT;
    top[cur] = tp;
    const int heavy = son[cur];
    if (heavy) {
        dfs2(heavy, tp);  // continue the current chain
    }
    for (int arc = head[cur]; arc != -1; arc = edge[arc].next) {
        const int nxt = edge[arc].to;
        // The heavy child is already done and the parent must not be revisited.
        if (nxt != heavy && nxt != fa[cur]) {
            dfs2(nxt, nxt);
        }
    }
}
int Find(int u, int v) {
int ans = 0;
int fu = top[u], fv = top[v];
while(fu != fv) {
if(dep[fu]<dep[fv]) swap(fu, fv), swap(u, v);
ans += query(1, 1, n, pos[fu], pos[u]);
u = fa[fu]; fu = top[u];
}
if(u==v) return ans;
else if(dep[u]<dep[v]) swap(u, v);
return ans+query(1, 1, n, pos[v]+1, pos[u]);
}
// Number of operations in the current test case.
int q;
// Node the walker currently stands on.
int s;
// Input edges, 1-based: e[i][0] and e[i][1] are the endpoints and e[i][2] the
// weight. After preprocessing e[i][0] is the deeper endpoint of edge i.
int e[MAXN][3];

// True when x is a usable node id for the current test case.
static bool nodeInRange(int x) {
    return x >= 1 && x <= n;
}

// Reads test cases until the header line can no longer be parsed. For each
// case: builds the tree, decomposes it from root 1, loads the edge weights
// into the segment tree and then serves the q operations
//   0 x      print the distance from the current node to x, then move to x
//   1 i val  set the weight of the i-th input edge to val
// Truncated input or an out-of-range value ends processing, like end of input,
// instead of indexing the fixed-size arrays with garbage.
int main() {
    //freopen("input.txt", "r", stdin);
    while (scanf("%d%d%d", &n, &q, &s) == 3) {
        // Node ids index arrays of MAXN entries and start at 1.
        if (n < 1 || n >= MAXN || !nodeInRange(s)) return 0;
        init();
        for (int i = 1; i < n; i++) {
            if (scanf("%d%d%d", &e[i][0], &e[i][1], &e[i][2]) != 3) return 0;
            if (!nodeInRange(e[i][0]) || !nodeInRange(e[i][1])) return 0;
            Add(e[i][0], e[i][1]);
            Add(e[i][1], e[i][0]);
        }
        dfs(1, 0);
        dfs2(1, 1);
        for (int i = 1; i < n; i++) {
            int& u = e[i][0];
            int& v = e[i][1];
            // Keep the deeper endpoint first: its position stores the weight.
            if (dep[u] < dep[v]) swap(u, v);
            update(1, 1, n, pos[u], e[i][2]);
        }
        for (int i = 1; i <= q; i++) {
            int op;
            if (scanf("%d", &op) != 1) return 0;
            if (op == 0) {
                int target;
                if (scanf("%d", &target) != 1 || !nodeInRange(target)) return 0;
                printf("%d\n", Find(s, target));
                s = target;
            }
            else {
                int idx, weight;
                if (scanf("%d%d", &idx, &weight) != 2) return 0;
                // Only edges 1 .. n-1 exist.
                if (idx < 1 || idx >= n) return 0;
                update(1, 1, n, pos[e[idx][0]], weight);
            }
        }
    }
    return 0;
}