树链剖分用一句话概括就是:把一棵树剖分为若干条链,然后利用数据结构(树状数组,SBT,Splay,线段树等等)去维护每一
void dfs1(int u,int father,int d)
{
dep[u]=d;
fa[u]=father;
siz[u]=1;
for(int i=head[u];~i;i=next[i])
{
int v=to[i];
if(v!=father)
{
dfs1(v,u,d+1);
siz[u]+=siz[v];
if(son[u]==-1||siz[v]>siz[son[u]])
son[u]=v;
}
}
}
void dfs2(int u,int tp)
{
top[u]=tp;
tid[u]=++tim;
rank[tid[u]]=u;
if(son[u]==-1) return;
dfs2(son[u],tp);
for(int i=head[u];~i;i=next[i])
{
int v=to[i];
if(v!=son[u]&&v!=fa[u])
dfs2(v,v);
}
}
Time Limit: 5000MS | Memory Limit: 131072K | |
Total Submissions: 6981 | Accepted: 1913 |
Description
You are given a tree with N nodes. The tree’s nodes are numbered 1 through N and its edges are numbered 1 through N − 1. Each edge is associated with a weight. Then you are to execute a series of instructions on the tree. The instructions can be one of the following forms:
CHANGE i v | Change the weight of the ith edge to v |
NEGATE a b | Negate the weight of every edge on the path from a to b |
QUERY a b | Find the maximum weight of edges on the path from a to b |
Input
The input contains multiple test cases. The first line of input contains an integer t (t ≤ 20), the number of test cases. Then follow the test cases.
Each test case is preceded by an empty line. The first nonempty line of its contains N (N ≤ 10,000). The next N − 1 lines each contains three integers a, b and c, describing an edge connecting nodes a and b with weight c. The edges are numbered in the order they appear in the input. Below them are the instructions, each sticking to the specification above. A lines with the word “DONE
” ends the test case.
Output
For each “QUERY
” instruction, output the result on a separate line.
Sample Input
1 3 1 2 1 2 3 2 QUERY 1 2 CHANGE 1 3 QUERY 1 2 DONE
Sample Output
1 3
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <vector>
#define INF 0x3f3f3f3f
#define MAX_N 10010
#define find_max(a,b) a>b?a:b
#define find_min(a,b) a>b?b:a
using namespace std;
int n;
struct edge{
int next;
int val;
};
struct num_edge{
int u,v,p;
};
int num=0;//编号变量
int start;//在线段树中的起始
int size[MAX_N];//用来保存以x为根的子树节点个数
int top[MAX_N];//用来保存当前节点的所在链的顶端节点
int son[MAX_N];//用来保存重儿子
int depth[MAX_N];//用来保存当前节点的重链的深度
int fa[MAX_N];//用来保存当前节点的父亲
int tid[MAX_N];//用来保存树中每个节点剖分后的新编号
int rank[MAX_N];//用来保存线段树中各位置对应的节点
int dat[4*MAX_N];//区间最大值线段树数组
int val[MAX_N];//用来保存该点到父亲的边权值
vector<edge> g[MAX_N];//存储节点的边
vector<num_edge> list;//按输入顺序存储边
void dat_change(int,int);
int get_count()
{//对树节点的个数取对数,小数点进位
int count=0;
int t=n;
while(t)
{
t/=2;
++count;
}
return count;
}
void init()
{//初始化
num=0;//标号置0
list.clear();//清空编号
for(int i=0;i<=n;++i)
{//清空边
g[i].clear();
son[i]=-1;
}
start=(1<<get_count())-1;//线段树用
for(int i=0;i<=start*2+1;++i)//把线段树的数组元素置为负无穷
dat[i]=-INF;
}
void add_edge(int u,int v,int p)
{//添边
g[u].push_back((edge){v,p});
g[v].push_back((edge){u,p});
}
void first_dfs(int u,int father)
{//第一深搜,确定每个节点的:重儿子、父亲、值
fa[u]=father;
size[u]=1;
for(int i=0;i<g[u].size();++i)
{
int v=g[u][i].next;
if(v!=father)
{
val[v]=g[u][i].val;
first_dfs(v,u);
size[u]+=size[v];
if(son[u]==-1||size[v]>size[son[u]])//son记录重儿子
son[u]=v;
}
}
}
void second_dfs(int u,int _top)
{//第二次深搜,确定每个节点的:重链顶部、标号、反标号、深度
top[u]=_top;
tid[u]=++num;
dat_change(u,val[u]);
rank[tid[u]]=u;
if(son[u]==-1)
return;
depth[son[u]]=depth[u];
second_dfs(son[u],_top);//优先搜索重儿子
for(int i=0;i<g[u].size();++i)
{//其后再搜索轻儿子
int v=g[u][i].next;
if(v!=son[u]&&v!=fa[u])
{
depth[v]=depth[u]+1;
second_dfs(v,v);
}
}
}
int get_lca(int a,int b)
{//寻找从a到b路径的LCA
int u=depth[a]>=depth[b]?a:b;
int v=depth[a]>=depth[b]?b:a;
while(depth[u]>depth[v])//深度齐平
u=fa[top[u]];
//上溯至同一重链
while(top[u]!=top[v])
{
u=fa[top[u]];
v=fa[top[v]];
}
//在同一重链中,标号是连续的,且标号小的为祖先,所以标号小的肯定为LCA
int lca=tid[u]>tid[v]?v:u;
return lca;
}
void dat_change(int a,int b)
{//线段树上的更新
int ndat=start+tid[a];
dat[ndat]=b;
while(ndat>0)
{
ndat/=2;
dat[ndat]=find_max(dat[ndat*2],dat[ndat*2+1]);
}
}
void change(int i,int val)
{//修改编号为i的边的值
int u=list[i-1].u;
int v=list[i-1].v;
if(fa[u]==v)
dat_change(u,val);
else dat_change(v,val);
list[i-1].p=val;
}
void negate(int a,int b)
{//对a到b路径上的边的值变反并更新
int lca=get_lca(a,b);
while(a!=lca)
{
dat_change(a,-1*dat[start+tid[a]]);
a=fa[a];
}
while(b!=lca)
{
dat_change(b,-1*dat[start+tid[b]]);
b=fa[b];
}
//change(lca,-1*dat[start+tid[lca]]);
}
int dat_query(int a,int b,int l,int r,int k)
{//线段树上的查询
// printf("a:%d b:%d l:%d r:%d k:%d\n",a,b,l,r,k);
//在查询区间外
if(r<a||b<l)
return -INF;
//在查询区间内
if(a<=l&&r<=b)
return dat[k];
else
{//与查询区间有交集
int vl=dat_query(a,b,l,(l+r)/2,2*k);
int vr=dat_query(a,b,(l+r)/2+1,r,2*k+1);
return find_max(vl,vr);
}
}
int query(int a,int b)
{
int res=-INF;
int lca=get_lca(a,b);
// printf("lca:%d\n", lca);
// printf("1\n");
while(top[a]!=top[lca])
{
res=find_max(res,dat_query(tid[top[a]],tid[a],1,start+1,1));
// printf("tid: %d - %d\n", tid[top[a]],tid[a]);
// printf("#tree: %d - %d\n", top[a],a);
a=fa[top[a]];
}
// printf("2\n");
while(top[b]!=top[lca])
{
res=find_max(res,dat_query(tid[top[b]],tid[b],1,start+1,1));
// printf("tid: %d - %d\n", tid[top[b]],tid[b]);
// printf("#tree: %d - %d\n", top[b],b);
b=fa[top[b]];
}
int u=tid[a]>tid[b]?b:a;
int v=tid[a]>tid[b]?a:b;
// printf("3 # u:%d v:%d\n",u,v);
if(u!=v)
{
u=son[u];
res=find_max(res,dat_query(tid[u],tid[v],1,start+1,1));
//printf("tid: %d - %d\n", tid[u],tid[v]);
// printf("#tree: %d - %d\n", u,v);
}
return res;
}
void solve()
{
int root=1;
//进行树链剖分
first_dfs(root,-1);
second_dfs(root,root);
// printf("start:%d \n",start);
// for(int i=1;i<=n;++i)
// {
// printf("size[%d]:%d ",i,size[i] );
// printf("top[%d]:%d ",i, top[i]);
// printf("son[%d]:%d ",i, son[i]);
// printf("fa[%d]:%d ", i,fa[i]);
// printf("tid[%d]:%d ", i,tid[i]);
// printf("val[%d]:%d ", i,val[i]);
// printf("\n");
// }
// for(int i=1;i<=2*start+1;++i)
// printf(" i:%d %d \n",i,dat[i]);
//处理改、查
char str[10];
char choice[4][10]={"QUERY","NEGATE","CHANGE","DONE"};
int a,b;
scanf("%s",str);
while(strcmp(str,choice[3]))
{
scanf("%d %d",&a,&b);
if(!strcmp(str,choice[0]))
printf("%d\n",query(a,b));
else if(!strcmp(str,choice[1]))
negate(a,b);
else if(!strcmp(str,choice[2]))
change(a,b);
scanf("%s",str);
}
}
int main()
{
int t;
scanf("%d",&t);
while(t--)
{
scanf("%d",&n);
init();
int u,v,p;
for(int i=0;i<n-1;++i)
{
scanf("%d %d %d",&u,&v,&p);
list.push_back((num_edge){u,v,p});
add_edge(u,v,p);
}
solve();
}
return 0;
}
/*
100
14
1 2 3
1 3 4
1 4 5
2 5 3
2 6 4
3 7 5
4 8 2
4 9 10
5 10 7
6 11 4
6 12 4
8 13 20
13 14 13
*/