树链剖分+线段树lazytag标记
染色
题目描述
给定一棵有 n 个节点的无根树和 m 个操作,操作共两类。
将节点 a 到节点 b 路径上的所有节点都染上颜色;
询问节点 a 到节点 b 路径上的颜色段数量,连续相同颜色的认为是同一段,例如 112221 由三段组成:11 、 222、1。
请你写一个程序依次完成操作。
输入格式
第一行包括两个整数 n,m,表示节点数和操作数;
第二行包含 n 个正整数表示 n 个节点的初始颜色;
接下来若干行包含两个整数 x 和 y,表示 x 和 y 之间有一条无向边;
接下来若干行每行描述一个操作:
C a b c 表示这是一个染色操作,把节点 a 到节点 b 路径上所有点(包括 a 和 b)染上颜色;
Q a b 表示这是一个询问操作,把节点 a 到节点 b 路径上(包括 a 和 b)的颜色段数量。
输出格式
对于每个询问操作,输出一行询问结果。
样例
Input Output
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Output
3
1
2
数据范围与提示
对于 100% 的数据,N,M≤105, 所有颜色 C 为整数且在 [0,109] 之间。
原题来自:HDU 6547
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+5;
struct node{
int to,nxt;
}ed[maxn<<1];//原树边节点
int head[maxn],cnt;
int dep[maxn],sz[maxn],son[maxn],p[maxn],top[maxn],fa[maxn];//sz原树节点个数//p 线段树的dfs序也是线段树节点序
int initcol[maxn],ndcol[maxn],pos;//initol线段树节点对应的原树节点的权值(颜色)//原树节权值(颜色),pos线段树节点计数器也是dfs序计数器
void addedge(int u,int v){//添加原树边
ed[cnt].to=v;
ed[cnt].nxt=head[u];
head[u]=cnt++;
}
void dfs1(int u,int pre,int d){//当前结点,父节点,深度;算深度,大小,重儿子,父亲
dep[u]=d;
sz[u]=1;
fa[u]=pre;
for(int i=head[u];~i;i=ed[i].nxt)//这里~是按位取反,~0=-1,~-1=0;
{
int v=ed[i].to;
if(v!=pre)//无向边存图跳过指向父节点的边
{
dfs1(v,u,d+1);
sz[u]+=sz[v];
if(sz[son[u]]<sz[v])
son[u]=v;
}
}
}
void dfs2(int u,int tp){//原树中当前结点,当前结点的根,建立dfs序
top[u]=tp;
p[u]=pos++;//pos为线段树的节点总数
initcol[p[u]]=ndcol[u];//把原树节点的权值(颜色)赋给当前赋值dfs序的节点
if(son[u]==-1)//没有重儿子(没儿子)
return ;
dfs2(son[u],tp);//先dfs重儿子
for(int i=head[u];~i;i=ed[i].nxt){//遍历其他儿子
int v=ed[i].to;
if(v==fa[u]||v==son[u])
continue;//跳过父亲和重儿子
dfs2(v,v);
}
}
struct T{
int l,r,lcol,rcol,mark,num;//mark是lazytag,num是当前结点管理区间的颜色个数
}tree[maxn<<2];//线段树节点
void pushup(int k){//更新线段树节点
tree[k].lcol=tree[k<<1].lcol;
tree[k].rcol=tree[(k<<1)|1].rcol;
tree[k].num=tree[k<<1].num+tree[(k<<1)|1].num-(tree[k<<1].rcol==tree[(k<<1)|1].lcol);
return ;
}
void pushdown(int k){//lazytag下放
if(tree[k].l==tree[k].r)
return ;
if(tree[k].mark==0)
return ;
tree[k<<1].mark=tree[(k<<1)|1].mark=tree[k].mark;
tree[k<<1].lcol=tree[k<<1].rcol=tree[(k<<1)|1].lcol=tree[(k<<1)|1].rcol=tree[k].mark;
tree[k<<1].num=tree[(k<<1)|1].num=1;//更新当前区间的左右子区间颜色颜色个数
tree[k].mark=0;//取消标记
return ;
}
void build(int l,int r,int k){//建立线段树
tree[k].l=l,tree[k].r=r;
tree[k].lcol=initcol[l];//注意
tree[k].rcol=initcol[r];
tree[k].mark=0;
if(l==r){
tree[k].num=1;
return ;
}
int mid=l+r>>1;
build(l,mid,k<<1);
build(mid+1,r,(k<<1)|1);
pushup(k);
return ;
}
void update(int l,int r,int col,int k){//区间更新
if(tree[k].l==l&&tree[k].r==r){//区间符合就打标
tree[k].lcol=tree[k].rcol=col;
tree[k].num=1;
tree[k].mark=col;
return ;
}
pushdown(k);//否则就先tag下传,然后递归更新区间
int mid=tree[k].l+tree[k].r>>1;
if(r<=mid)
update(l,r,col,k<<1);
else if(l>mid)
update(l,r,col,(k<<1)|1);
else {
update(l,mid,col,k<<1);
update(mid+1,r,col,(k<<1)|1);
}
pushup(k);
return ;
}
int ask(int a,int k){//利用线段树找到线段树上a的确定颜色
if(a==tree[k].r)
return tree[k].rcol;
if(a==tree[k].l)
return tree[k].lcol;
pushdown(k);
int mid=tree[k].l+tree[k].r>>1;
if(a<=mid)
return ask(a,k<<1);
else if(a>mid)
return ask(a,(k<<1)|1);
}
int query(int l,int r,int k){//查询线段树l到r上的颜色块总数
if(tree[k].l==l&&tree[k].r==r)
return tree[k].num;
pushdown(k);
int mid=(tree[k].l+tree[k].r)>>1;
if(r<=mid)
return query(l,r,k<<1);
else if(l>mid)
return query(l,r,(k<<1)|1);
else
return query(l,mid,k<<1)+query(mid+1,r,(k<<1)|1)-(tree[k<<1].rcol==tree[(k<<1)|1].lcol);
}
void up(int u,int v,int col)//把原树上u到v的节点颜色换成col
{
int f1=top[u],f2=top[v];
while(f1!=f2){//u,v不在一条重连上
if(dep[f1]<dep[f2]){//让f1和u对应的重链的根的深度较大
swap(f1,f2);
swap(u,v);
}
update(p[f1],p[u],col,1);//在线段树上更新根深度较大的重链
u=fa[f1];
f1=top[u];
}//退出while的时候两个节点已经在一条链上了
if(dep[u]<dep[v])
swap(u,v);// 确保u的深度大因为在一条链上深度大的dfs序在后
update(p[v],p[u],col,1);
}
int getnum(int u,int v)//求解在u到v上的颜色段
{
int f1=top[u],f2=top[v];
int re=0;
while(f1!=f2){
if(dep[f1]<dep[f2]){
swap(f1,f2);
swap(u,v);
}
re+=query(p[f1],p[u],1)-(ask(p[fa[f1]],1)==ask(p[f1],1));//注意这里在换链时候的区块颜色
//另外因为有lazytag标记导致当前重链根的颜色和该根的父亲不一定是真正的颜色所以要用线段树先维护
u=fa[f1];
f1=top[u];
}
if(dep[u]<dep[v])
swap(u,v);
re+=query(p[v],p[u],1);
return re;
}
int main()
{
memset(son,-1,sizeof(son));
memset(head,-1,sizeof(head));
int n,m;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++){
scanf("%d",&ndcol[i]);
}
for(int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
addedge(u,v);
addedge(v,u);
}
dfs1(1,0,0);
dfs2(1,1);
build(0,n,1);
while(m--){
char str[2];
int a,b,c;
scanf("%s%d%d",str,&a,&b);
if(str[0]=='C'){
scanf("%d",&c);
up(a,b,c);
}
else {
printf("%d\n",getnum(a,b));
}
getchar();
}
return 0;
}
2021.7.19