题意: 给定一棵有 n 个节点的树以及树上每条边的权值,支持以下两种操作:
【LCA_RMQ+树状数组】
0 x 输出 当前节点到 x节点的最短距离,并移动到 x 节点位置
1 x val 把第 x 条边的权值改为 val
分析: 树上两个节点a,b的距离可以转化为
dis[a] + dis[b] - 2*dis[lca(a,b)]
其中 dis[i] 表示 i 节点到根的距离,
由于修改一条边的权值时,树中位于这条边下方(即以该边下端点为根的子树)的所有节点的 dis[] 值都会受到影响,因此每条边都对应一段由它管辖的区间。
可以在深搜时记录时间戳: ll[i] 表示第一次遍历到点 i 时的时间戳, rr[i] 表示回溯离开点 i 时的时间戳(即 i 的子树中最大的时间戳)。设边 i 的下端点为 g[i],
修改边 i 时只需对区间 [ ll[g[i]], rr[g[i]] ] 进行成段更新。成段更新的方式是: 在位置 ll[g[i]] 上加上该边的权值(修改时加上新旧权值之差),在位置
rr[g[i]]+1 上减去同样的值; 求和时, sum(ll[u]) 即为点 u 到根的距离。
#include<stdio.h>
#include<string.h>
#include<math.h>
#define clr(x)memset(x,0,sizeof(x))
#define maxn 200005
/* Adjacency-list arc: target vertex `to`, index of the next arc leaving the
   same vertex (`next`, 0 ends the list because main() numbers arcs from
   tot=1), weight `w`, and the 1-based input index `xu` of the edge it came from. */
struct node
{
    int to,next,w,xu;
}e[1000000];
int tot;
int head[maxn];
/* Prepend arc s -> t (weight wi, input edge index xu) to s's adjacency list. */
void add(int s,int t,int wi,int xu)
{
    e[tot].xu=xu;
    e[tot].w=wi;
    e[tot].to=t;
    e[tot].next=head[s];
    head[s]=tot++;
}
/* Sparse table over Euler-tour positions: dp[i][j] is the position of minimum
   depth in [i, i + 2^j). The tour has 2n-1 entries, so levels up to
   floor(log2(2*maxn)) = 18 can be used -> 19 columns. */
int dp[maxn<<1][19];
int x[maxn<<1];   /* x[p]: vertex at Euler-tour position p */
int d[maxn<<1];   /* d[p]: depth at Euler-tour position p (same length as x[]) */
int r[maxn];      /* r[u]: first Euler-tour position of vertex u */
int v[maxn];      /* v[u]: visited flag used by dfs */
int f[maxn];      /* NOTE(review): not read anywhere in this program */
int ll[maxn];     /* ll[u]: DFS entry timestamp of u */
int rr[maxn];     /* rr[u]: largest timestamp inside u's subtree */
int g[maxn];      /* g[i]: lower (child) endpoint of input edge i */
int n,m;
/* Of two Euler-tour positions, return the one with the smaller depth
   (ties go to j). */
int min(int i,int j)
{
    return d[i]<d[j]?i:j;
}
/* Build the sparse table for Euler-tour positions 0..nn-1. */
void makermq(int nn)
{
    int i,j;
    for(i=0;i<nn;i++)
        dp[i][0]=i;
    for(j=1;(1<<j)<=nn;j++)
        /* Start at 0: position 0 (the root's first visit) is part of the tour. */
        for(i=0;i+(1<<j)-1<nn;i++)
            dp[i][j]=min(dp[i][j-1],dp[i+(1<<(j-1))][j-1]);
}
/* Position of minimum depth in the Euler-tour range [l, r] (requires l <= r).
   k = floor(log2(r-l+1)) is computed with integer shifts, avoiding the
   rounding risk of floating-point log(). */
int rmq(int l,int r)
{
    int k=0;
    while((2<<k)<=r-l+1)
        k++;
    return min(dp[l][k],dp[r-(1<<k)+1][k]);
}
int cnt,ti;
/* Depth-first traversal of vertex u at depth `deep`.
   - Appends u to the Euler tour (x[], d[]) on entry and again after each
     child returns; r[u] is u's first tour position.
   - ll[u] is u's entry timestamp; rr[u] is the largest timestamp in u's subtree.
   - For each tree edge, g[input edge index] becomes its child endpoint. */
void dfs(int u,int deep)
{
    int arc,child;
    v[u]=1;
    x[cnt]=u;
    d[cnt]=deep;
    r[u]=cnt;
    cnt++;
    ti++;
    ll[u]=ti;
    for(arc=head[u];arc!=0;arc=e[arc].next)
    {
        child=e[arc].to;
        if(v[child])
            continue;
        g[e[arc].xu]=child;
        dfs(child,deep+1);
        /* back at u after finishing the child's subtree */
        x[cnt]=u;
        d[cnt]=deep;
        cnt++;
    }
    rr[u]=ti;
}
int tree[maxn];
/* Lowest set bit of x (0 when x == 0). */
int lowbit(int x)
{
    return x & -x;
}
/* Fenwick point update: add x at position pos, for positions 1..n.
   pos must be >= 1: for pos == 0, lowbit(pos) is 0 and the loop would never advance. */
void update(int pos,int x)
{
    for (; pos <= n; pos += lowbit(pos))
        tree[pos] += x;
}
/* Fenwick prefix sum over positions 1..pos (0 when pos <= 0). */
int sum(int pos)
{
    int total = 0;
    for (; pos > 0; pos -= lowbit(pos))
        total += tree[pos];
    return total;
}
int edge[maxn];
int val[maxn];
/* Processes test cases until EOF. Each case: n vertices, m operations,
   start vertex st, then n-1 weighted edges, then m operations:
     0 a     print dist(st, a) and move to a
     1 a b   set the weight of input edge a to b
   dist(u, w) = dis[u] + dis[w] - 2*dis[lca(u, w)], where dis[] (distance to
   the root) is the Fenwick prefix sum at the vertex's entry timestamp. */
int main()
{
    int i,st;
    while(scanf("%d%d%d",&n,&m,&st)!=EOF)
    {
        int a,b,c;
        clr(head); clr(v);
        clr(tree);
        tot=1;   /* arc 0 is reserved so that head[]==0 marks an empty list */
        ti=-1;   /* the root receives timestamp 0, so sum(ll[root]) == 0 */
        for(i=1;i<n;i++)
        {
            scanf("%d%d%d",&a,&b,&c);
            val[i]=c;
            edge[i]=c;
            add(a,b,c,i);
            add(b,a,c,i);
        }
        cnt=0;
        dfs(1,0);
        makermq(2*n-1);
        /* Edge i contributes its weight to every vertex in the subtree of its
           lower endpoint g[i], i.e. to timestamps [ll[g[i]], rr[g[i]]]. */
        for(i=1;i<n;i++)
        {
            update(ll[g[i]],edge[i]);
            update(rr[g[i]]+1,-edge[i]);
        }
        int op;
        while(m--)
        {
            scanf("%d",&op);
            if(op==1)
            {
                scanf("%d%d",&a,&b);
                /* Apply only the change in weight: two Fenwick updates instead of four. */
                int delta=b-val[a];
                update(ll[g[a]],delta);
                update(rr[g[a]]+1,-delta);
                val[a]=b;
            }
            else
            {
                scanf("%d",&a);
                int lca,d1,d2,d3;
                /* rmq expects its range endpoints in increasing order */
                if(r[st]<=r[a])
                    lca=x[rmq(r[st],r[a])];
                else lca=x[rmq(r[a],r[st])];
                d1=sum(ll[st]);
                d2=sum(ll[a]);
                d3=sum(ll[lca]);
                st=a;
                printf("%d\n",d1+d2-2*d3);
            }
        }
    }
    return 0;
}
解法二: 利用线段树(区间加、单点查询)维护。对每条边,把其下端点(边尾)子树所对应的时间戳区间内的值都加上该边的权值,这样某个点处的值就等于它到根节点的距离!!
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn = 200010;
/* Edge record: endpoints u -> v and weight w; inside edges[] `next` is the
   index of the next arc leaving u (-1 ends the list). e[i] keeps input edge i. */
struct edge
{
    int u, v, w, next;
}edges[maxn*2], e[maxn];
/* Euler tour: E[t] vertex at tour step t, H[t] its depth, I[u] first step of u.
   L[u]: DFS entry timestamp of u; R[u]: largest timestamp in u's subtree. */
int E[maxn*2], H[maxn*2], I[maxn*2], L[maxn], R[maxn];
/* Sparse table over tour steps 1..2n-1: dp[i][j] is the step of minimum
   depth in [i, i + 2^j). The highest level used is floor(log2(2*maxn)) = 18,
   so 20 columns are enough (40 columns would waste about 32 MB). */
int dp[maxn*2][20];
/* NOTE(review): a global named `clock` collides with ::clock() whenever
   <ctime>/<time.h> is included (directly or transitively); renaming it also
   requires touching dfs() and main(). */
int cnt, clock, dfn;
int first[maxn];    /* first[u]: first arc leaving u, -1 if none */
int a[maxn<<2];     /* segment tree: sum over each node's range */
int b[maxn];        /* b[t]: vertex whose DFS timestamp is t */
int add[maxn<<2];   /* segment tree: pending lazy addition per node */
int degree[maxn];
int vis[maxn];
/* Insert the undirected edge (u, v) as the two arcs u->v and v->u, both with weight w. */
void AddEdge(int u, int v, int w)
{
    edges[cnt].u = u;
    edges[cnt].v = v;
    edges[cnt].w = w;
    edges[cnt].next = first[u];
    first[u] = cnt++;
    edges[cnt].u = v;
    edges[cnt].v = u;
    edges[cnt].w = w;
    edges[cnt].next = first[v];
    first[v] = cnt++;
}
/* Depth-first traversal of u (parent fa, depth dep).
   Appends u to the Euler tour (E/H) on entry and after every child, records
   I[u] (first tour step), L[u] (entry timestamp, with b[L[u]] = u) and R[u]
   (largest timestamp in the subtree). The caller marks the start vertex in
   vis[] before the first call. */
void dfs(int u, int fa, int dep)
{
    ++clock;
    E[clock] = u;
    H[clock] = dep;
    I[u] = clock;
    ++dfn;
    L[u] = dfn;
    b[dfn] = u;
    for(int k = first[u]; k != -1; k = edges[k].next)
    {
        const int to = edges[k].v;
        if(to == fa || vis[to])
            continue;
        vis[to] = true;
        dfs(to, u, dep + 1);
        /* back at u after the child's subtree */
        ++clock;
        E[clock] = u;
        H[clock] = dep;
    }
    R[u] = dfn;
}
/* Build the sparse table over tour steps 1..n; dp[i][j] holds the step of
   minimum depth H[] in [i, i + 2^j), preferring the right half on ties. */
void RMQ_init(int n)
{
    for(int pos = 1; pos <= n; ++pos)
        dp[pos][0] = pos;
    for(int lvl = 1; (1 << lvl) <= n; ++lvl)
    {
        const int half = 1 << (lvl - 1);
        for(int pos = 1; pos + (1 << lvl) - 1 <= n; ++pos)
        {
            const int lhs = dp[pos][lvl - 1];
            const int rhs = dp[pos + half][lvl - 1];
            dp[pos][lvl] = (H[lhs] < H[rhs]) ? lhs : rhs;
        }
    }
}
int RMQ(int l, int r)
{
l = I[l], r = I[r];
if(l > r)
swap(l, r);
int len = r-l+1, k = 0;
while((1<<k) <= len)
k++;
k--;
if(H[dp[l][k]] < H[dp[r-(1<<k)+1][k]])
return E[dp[l][k]];
else
return E[dp[r-(1<<k)+1][k]];
}
/* Push node rt's pending addition (covering [l, r]) down to its children.
   The left child covers ceil(len/2) positions, the right child floor(len/2). */
void pushdown(int rt, int l, int r)
{
    const int tag = add[rt];
    if(!tag)
        return;
    const int len = r - l + 1;
    const int rightLen = len >> 1;
    const int leftLen = len - rightLen;
    a[rt << 1] += tag * leftLen;
    a[rt << 1 | 1] += tag * rightLen;
    add[rt << 1] += tag;
    add[rt << 1 | 1] += tag;
    add[rt] = 0;
}
/* Reset the segment tree node rt covering [l, r] and all of its descendants to zero. */
void build(int l, int r, int rt)
{
    a[rt] = add[rt] = 0;
    if(l != r)
    {
        const int mid = (l + r) >> 1;
        build(l, mid, rt << 1);
        build(mid + 1, r, rt << 1 | 1);
    }
}
/* Add num to every position in [x, y] within node rt, which covers [l, r].
   Callers pass x <= y, with [x, y] inside [l, r]. */
void update(int x, int y, int l, int r, int rt, int num)
{
    if(x == l && y == r)
    {
        add[rt] += num;
        a[rt] += num * (r - l + 1);
        return;
    }
    pushdown(rt, l, r);
    const int mid = (l + r) >> 1;
    const int lc = rt << 1, rc = rt << 1 | 1;
    if(y <= mid)
    {
        update(x, y, l, mid, lc, num);
    }
    else if(x > mid)
    {
        update(x, y, mid + 1, r, rc, num);
    }
    else
    {
        update(x, mid, l, mid, lc, num);
        update(mid + 1, y, mid + 1, r, rc, num);
    }
    a[rt] = a[lc] + a[rc];
}
/* Value stored at position x, searching from node rt, which covers [l, r]. */
int query(int x, int l, int r, int rt)
{
    if(l == r)
        return a[rt];
    pushdown(rt, l, r);
    const int mid = (l + r) >> 1;
    const int ans = (x <= mid) ? query(x, l, mid, rt << 1)
                               : query(x, mid + 1, r, rt << 1 | 1);
    a[rt] = a[rt << 1] + a[rt << 1 | 1];
    return ans;
}
/* Processes test cases until EOF: n vertices, q operations, start vertex s,
   then n-1 weighted edges, then q operations:
     0 to     print dist(s, to) and move to `to`
     1 i w    set the weight of input edge i to w
   The segment tree holds, at each vertex's DFS timestamp, its distance to the
   root, so dist(s, to) = d(s) + d(to) - 2*d(lca). */
int main()
{
    int s, to, n, q;
    while(scanf("%d %d %d", &n, &q, &s) != EOF)
    {
        memset(vis, 0, sizeof(vis));
        memset(first, -1, sizeof(first));
        memset(degree, 0, sizeof(degree));
        clock = cnt = dfn = 0;
        build(1, n, 1);
        for(int i = 1; i < n; i++)
        {
            int u, v, w;
            scanf("%d %d %d", &u, &v, &w);
            e[i].u = u;
            e[i].v = v;
            e[i].w = w;
            AddEdge(u, v, 0);
            degree[v]++;
        }
        /* Root the tree at the first vertex that never appears as the second
           endpoint of an input edge; one always exists because the n-1 edges
           give a total in-degree of n-1 spread over n vertices. */
        for(int i = 1; i <= n; i++)
            if(!degree[i])
            {
                vis[i] = true;
                dfs(i, -1, 0);
                break;
            }
        RMQ_init(2*n-1);
        /* The child endpoint of an edge is the one discovered later; the edge
           weight applies to the child's whole subtree [L[child], R[child]]. */
        for(int i = 1; i < n; i++)
        {
            const int child = (L[e[i].u] < L[e[i].v]) ? e[i].v : e[i].u;
            update(L[child], R[child], 1, n, 1, e[i].w);
        }
        while(q--)
        {
            int op;
            scanf("%d", &op);
            if(!op)
            {
                scanf("%d", &to);
                const int d1 = query(L[s], 1, n, 1);
                const int d2 = query(L[to], 1, n, 1);
                const int lca = RMQ(s, to);
                const int d3 = query(L[lca], 1, n, 1);
                printf("%d\n", d1+d2-2*d3);
                s = to;
            }
            else
            {
                int idx, w;
                scanf("%d %d", &idx, &w);
                const int delta = w - e[idx].w;
                e[idx].w = w;
                const int child = (L[e[idx].u] < L[e[idx].v]) ? e[idx].v : e[idx].u;
                update(L[child], R[child], 1, n, 1, delta);
            }
        }
    }
    return 0;
}
解法三: 树链剖分(把每条边的权值存放在深度较大的端点上,用线段树维护重链上的区间和,查询时沿重链向上跳,累加路径上的边权)
#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;
#define Del(a,b) memset(a,b,sizeof(a))
const int N = 100005;
// Per-vertex heavy-light decomposition data:
//   dep - depth          siz - subtree size       fa  - parent
//   id  - position in the segment-tree order      son - heavy child (0 = none)
//   val - value stored at each segment-tree position (edge weight of the deeper endpoint)
//   top - head (topmost vertex) of the heavy chain containing the vertex
int dep[N], siz[N], fa[N], id[N], son[N], val[N], top[N];
int num, head[N], cnt;

// Adjacency-list arc: target vertex and index of the next arc from the same source (-1 ends the list).
struct Edge
{
    int to, next;
};
Edge v[N * 2];

// One input edge (x, y) with weight val.
struct tree
{
    int x, y, val;
    void read() { scanf("%d%d%d", &x, &y, &val); }
};

// Prepend the arc x -> y to x's adjacency list.
void add_Node(int x, int y)
{
    Edge &arc = v[cnt];
    arc.to = y;
    arc.next = head[x];
    head[x] = cnt;
    ++cnt;
}
tree e[N];
// First HLD pass: record depth d, parent f and subtree size of u, and pick
// u's heavy child (the child with the largest subtree; earlier child wins ties).
void dfs1(int u, int f, int d) {
    fa[u] = f;
    dep[u] = d;
    son[u] = 0;
    siz[u] = 1;
    for (int k = head[u]; k != -1; k = v[k].next) {
        const int w = v[k].to;
        if (w == f)
            continue;
        dfs1(w, u, d + 1);
        siz[u] += siz[w];
        if (siz[w] > siz[son[u]])
            son[u] = w;
    }
}
// Second HLD pass: give u the next position id[u] and chain head tp. The heavy
// child is visited first so that each heavy chain gets consecutive positions;
// every light child starts a new chain headed by itself.
void dfs2(int u, int tp) {
    top[u] = tp;
    id[u] = ++num;
    const int heavy = son[u];
    if (heavy)
        dfs2(heavy, tp);
    for (int k = head[u]; k != -1; k = v[k].next) {
        const int w = v[k].to;
        if (w != fa[u] && w != heavy)
            dfs2(w, w);
    }
}
#define lson(x) ((x<<1))
#define rson(x) ((x<<1)+1)
// Segment-tree node covering positions [l, r]; val is the sum over that range.
struct Tree
{
    int l, r, val;
};
Tree tree[4 * N];
// Refresh node x's sum from its two children.
void pushup(int x) {
    const int leftSum = tree[lson(x)].val;
    const int rightSum = tree[rson(x)].val;
    tree[x].val = leftSum + rightSum;
}
void build(int l,int r,int v)
{
tree[v].l=l;
tree[v].r=r;
if(l==r){
tree[v].val = val[l];
return ;
}
int mid=(l+r)>>1;
build(l,mid,v*2);
build(mid+1,r,v*2+1);
pushup(v);
}
// Point assignment in O(log n): set position v to val below node o, then
// recompute the sums of every ancestor on the way back up to o.
void update(int o,int v,int val)
{
    int node = o;
    while (tree[node].l != tree[node].r) {
        const int mid = (tree[node].l + tree[node].r) / 2;
        node = (v <= mid) ? node * 2 : node * 2 + 1;
    }
    tree[node].val = val;
    while (node != o) {
        node /= 2;
        pushup(node);
    }
}
int query(int o,int l, int r)
{
if (tree[o].l >= l && tree[o].r <= r) {
return tree[o].val;
}
int mid = (tree[o].l + tree[o].r) / 2;
if(r<=mid)
return query(o+o,l,r);
else if(l>mid)
return query(o+o+1,l,r);
else
return query(o+o,l,mid) + query(o+o+1,mid+1,r);
}
// Sum of edge weights on the path between u and v. Each edge weight lives at
// the position of its deeper endpoint, so the final segment starts at son[u].
int Yougth(int u, int v) {
    int total = 0;
    // Lift the endpoint whose chain head is deeper until both share a chain.
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]])
            swap(u, v);
        total += query(1, id[top[u]], id[u]);
        u = fa[top[u]];
    }
    if (u == v)
        return total;
    if (dep[u] > dep[v])
        swap(u, v);
    return total + query(1, id[son[u]], id[v]);
}
int main()
{
int n,m,s;
while(~scanf("%d%d%d",&n,&m,&s))
{
cnt = 0;
memset(head,-1,sizeof(head));
for(int i=1;i<n;i++)
{
e[i].read();
add_Node(e[i].x,e[i].y);
add_Node(e[i].y,e[i].x);
}
num = 0;
dfs1(1,0,1);
dfs2(1,1);
for(int i=1;i<n;i++)
{
if(dep[e[i].x] < dep[e[i].y])
swap(e[i].x,e[i].y);
val[id[e[i].x]] = e[i].val;
}
build(1,num,1);
for(int i=0;i<m;i++)
{
int ok,x,y;
scanf("%d",&ok);
if(ok==0)
{
scanf("%d",&x);
printf("%d\n",Yougth(s,x));
s = x;
}
else
{
scanf("%d%d",&x,&y);
update(1,id[e[x].x],y);
}
}
}
return 0;
}