树链剖分,就是将一棵树看成由多条链组成的,每条链就是所谓的重链,链与链之间通过树的一条边相连,这条边就是所谓的轻边。然后如果要用线段树,每次更新时,只要保证每条链上的点在线段树中的位置是连续的,便可以使用线段树的各种手段了。
至于为什么沿节点数最多的方向扩展重链,我猜测是因为节点数多容易被问到,那么这个地方的链越统一越好,所以沿这个方向扩展重链。
根本思想还是将树看成由多条链组成的,不要被线段树拘泥了思维,将树化成链之后,便天高任鸟飞了,不一定要用线段树去解决问题。
附两个题:
hdu3966 模板题:
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define delf int m=(l+r)>>1
using namespace std;
const int MAX=50050;
int sum[MAX<<2];
int fa[MAX]; //父节点
int top[MAX]; //所在链的链头
int num[MAX]; //该节点为根子树的节点数
int son[MAX]; //子树最多的孩子编号
int pos[MAX]; //节点在线段数中的位置
int dep[MAX]; //节点深度
int v[MAX]; //节点初始值
int cnt;
vector <int> mv[MAX];
int n,m,p;
void init()
{
for (int i=1;i<=n;i++)
{
fa[i]=num[i]=son[i]=pos[i]=dep[i]=-1;
mv[i].clear();
}
cnt=0;
return ;
}
void dfs1(int u,int d,int f) //第一遍DFS,获得父节点,深度,孩子
{
fa[u]=f; //父亲
dep[u]=d; //深度
num[u]=1; //子树节点数
int s=mv[u].size(); //孩子数
for (int i=0;i<s;i++)
{
int next=mv[u][i];
if (next==f) //如果是父亲节点,跳过
continue ;
dfs1(next,d+1,u); //next的父亲为u,深度为d+1
if (son[u]==-1||num[next]>num[son[u]]) //更新孩子,要求是孩子中节点数最多的
son[u]=next;
num[u]+=num[next]; //更新该节点孩子数
}
return ;
}
void dfs2(int u,int t)
{
top[u]=t; //所在链的顶端
pos[u]=(++cnt); //在线段树中的位置
if (son[u]==-1) //没有孩子的情况
return ;
dfs2(son[u],t); //先扫描孩子的情况,孩子此时跟当前节点在一条链上
int s=mv[u].size(); //孩子数量
for (int i=0;i<s;i++)
{
int next=mv[u][i];
if (next==fa[u]||next==son[u]) //如果是父节点或者是跟当前点在一条链上,跳过
continue ;
dfs2(next,next); //以孩子节点为头,重开一条链
}
return ;
}
void pushdown(int l,int r,int rt) //线段树部分,向下更新
{
if (sum[rt]!=0)
{
sum[rt<<1]+=sum[rt];
sum[rt<<1|1]+=sum[rt];
sum[rt]=0;
}
return ;
}
void build(int l,int r,int rt)
{
sum[rt]=0;
if (l==r)
return ;
delf;
build (lson);
build (rson);
return ;
}
void update(int L,int R,int v,int l,int r,int rt)
{
if (L<=l&&r<=R)
{
sum[rt]+=v;
return ;
}
delf;
if (L<=m)
update(L,R,v,lson);
if (R>m)
update(L,R,v,rson);
return ;
}
int query(int k,int l,int r,int rt)
{
if (l==r)
return sum[rt];
pushdown(l,r,rt);
delf;
if (k<=m)
return query(k,lson);
return query(k,rson);
}
int main()
{
while (~scanf("%d%d%d",&n,&m,&p))
{
init();
for (int i=1;i<=n;i++)
scanf("%d",&v[i]);
for (int i=1;i<n;i++)
{
int a,b;
scanf("%d%d",&a,&b);
mv[a].push_back(b);
mv[b].push_back(a);
}
dfs1(1,0,0);
dfs2(1,1);
build(1,n,1);
while (p--)
{
char ch[2];
getchar();
scanf("%s",ch);
if (ch[0]!='Q')
{
int a,b,v;
scanf("%d%d%d",&a,&b,&v);
if (ch[0]=='D')
v=-v;
while (top[a]!=top[b])
{
if (dep[top[a]]>dep[top[b]])
{
update(pos[top[a]],pos[a],v,1,n,1);
a=fa[top[a]];
}
else
{
update(pos[top[b]],pos[b],v,1,n,1);
b=fa[top[b]];
}
}
if (dep[a]>dep[b])
update(pos[b],pos[a],v,1,n,1);
else
update(pos[a],pos[b],v,1,n,1);
}
else
{
int a;
scanf("%d",&a);
int ans=query(pos[a],1,n,1);
printf("%d\n",ans+v[a]);
}
}
}
}
hdu5044 这题就不能用线段树去解决,线段树会超时,其实你想,他只有一次询问,用线段树肯定大材小用,最经典的标记两端的方法肯定能解决问题,如果是线性,直接在一条链上标记,这题剖分后不就是多条链嘛,每条链标记一下,然后最后将所有链都给跑一遍不就成了。
代码:
#pragma comment(linker,"/STACK:102400000,102400000")
#include <iostream>
#include <cstring>
#include <cstdio>
#define ll long long int
using namespace std;
const int MAX=100010;
int top[MAX];
int fa[MAX];
int dep[MAX];
int num[MAX];
int head[MAX];
int son[MAX];
ll cn[MAX]; //点的变动
ll ce[MAX]; //边的变动
ll ansn[MAX]; //点的结果
ll anse[MAX]; //边的结果
int mark[MAX];
int n,m;
struct node
{
int id;
int to;
int next;
} edge[MAX<<1];
void init()
{
for (int i=1;i<=n;i++)
head[i]=top[i]=son[i]=fa[i]=num[i]=-1;
for (int i=1;i<=n;i++)
cn[i]=ce[i]=ansn[i]=anse[i]=0;
return ;
}
void add(int a,int b,int i)
{
edge[i].to=b;
edge[i].next=head[a];
edge[i].id=(i+1)/2;
head[a]=i;
return ;
}
void dfs1(int u,int d,int f)
{
fa[u]=f;
dep[u]=d;
num[u]=1;
int next=head[u];
while (next!=-1)
{
int to=edge[next].to;
if (to!=f)
{
mark[to]=edge[next].id; //to这个节点头上的边的编号
dfs1(to,d+1,u);
if (son[u]==-1||num[son[u]]<num[to])
son[u]=to;
num[u]+=num[to];
}
next=edge[next].next;
}
return ;
}
void dfs2(int u,int t)
{
top[u]=t;
if (son[u]==-1)
return ;
dfs2(son[u],t);
int next=head[u];
while (next!=-1)
{
int to=edge[next].to;
next=edge[next].next;
if (to==fa[u]||to==son[u])
continue ;
dfs2(to,to);
}
return ;
}
void swap(int &a,int &b)
{
int t=a;
a=b;
b=t;
return ;
}
void change_node(int a,int b,int v)
{
while (top[a]!=top[b])
{
if (dep[top[a]]<dep[top[b]])
swap(a,b);
cn[a]+=v;
cn[fa[top[a]]]-=v;
a=fa[top[a]];
}
if (dep[a]<dep[b])
swap(a,b);
cn[a]+=v;
cn[fa[b]]-=v;
return ;
}
void change_edge(int a,int b,int v)
{
while (top[a]!=top[b])
{
if (dep[top[a]]<dep[top[b]])
swap(a,b);
ce[a]+=v;
ce[fa[top[a]]]-=v;
a=fa[top[a]];
}
if (dep[a]<dep[b])
swap(a,b);
ce[a]+=v;
ce[b]-=v;
//cout<<ce[a]<<" "<<ce[b]<<endl;
return ;
}
void dfs(int u)
{
int next=head[u];
while (next!=-1)
{
int to=edge[next].to;
if (to!=fa[u])
{
dfs(to);
cn[u]+=cn[to];
ce[u]+=ce[to];
}
next=edge[next].next;
}
ansn[u]+=cn[u];
anse[mark[u]]+=ce[u];
return ;
}
int main()
{
int T;
scanf("%d",&T);
for (int r=1;r<=T;r++)
{
scanf("%d%d",&n,&m);
init();
for (int i=1;i<n;i++)
{
int a,b;
scanf("%d%d",&a,&b);
add(a,b,i*2-1);
add(b,a,i*2);
}
dfs1(1,0,0);
dfs2(1,1);
for (int i=1;i<=m;i++)
{
char s[10];
int a,b,v;
scanf("%s%d%d%d",s,&a,&b,&v);
if (s[3]=='1')
change_node(a,b,v);
else
change_edge(a,b,v);
}
dfs(1);
printf("Case #%d:\n",r);
printf("%I64d",ansn[1]);
for (int i=2;i<=n;i++)
printf(" %I64d",ansn[i]);
printf("\n");
if (n>1)
printf("%I64d",anse[1]);
for (int i=2;i<n;i++)
printf(" %I64d",anse[i]);
printf("\n");
}
}