知识储备:线段树、图的存储与遍历、状态压缩
题意翻译
你有一棵以1为根的有根树,有n个点,每个节点初始有一个颜色c[i]。
有两种操作:
1 v c 将以v为根的子树中所有点颜色更改为c
2 v 查询以v为根的子树中的节点有多少种不同的颜色
题目链接:https://www.luogu.org/problem/CF620E
首先我们讲一下dfs序在这道题里面的应用
定义:一棵树从根节点开始进行深度优先搜索,用一个时间戳记录下来每一个点被访问的时间,得到的序列就叫dfs序。
注意:以不同方式存树会有不同的dfs序,但是对于同一个邻接表搜出来的dfs序永远一样
这棵树的dfs序(不唯一)就是:1 4 6 3 7 10 5 8 2 9
这样的话我们就可以把一个点的dfs序代表这个点,这样不论树的形状是怎样的,dfs序都可以把它转化成线性结构
我们通过dfs把这颗树的dfs序存储在pos数组中
同时,我们还要记录一个点的入点时间戳与出点时间戳。因为题目要求我们支持更改一整个子树,如果我们把它转化成线性结构就可以用线段树维护它而不用dfs了,但是线段树要有左右端点,而这个左右端点就是用in与out数组实现。
dfs代码:
我是用邻接表存图,tot为时间,pos记录当前时间访问到的结点。
void dfs(int x,int fa){//用dfs序将树形结构转为线性结构
tot++;
in[x]=tot;
pos[tot]=x;
for(int i=head[x];i+1;i=e[i].next){
int k=e[i].to;
if(k==fa)continue;
dfs(k,x);
}
out[x]=tot;//其实我们out记录的是该子树中dfs序最大结点的dfs序,所以tot不加一
return;
}
这样的话,我们就可以用pos数组建树了!
struct tree{
long long sum;
long long tag;//记录要修改为哪个颜色
int l,r;
}t[1600040];
我们把颜色状态压缩到一个long long中,然后要建立一颗与运算的线段树:
首先是上传、下传以及建树:
void pushup(int rt){
t[rt].sum=t[rt<<1].sum|t[rt<<1|1].sum;
return;
}
void pushdown(int rt){
if(t[rt].tag!=0){//或运算不像区间和,这里0是影响答案的
t[rt<<1].sum=t[rt].tag;
t[rt<<1].tag=t[rt].tag;
t[rt<<1|1].sum=t[rt].tag;
t[rt<<1|1].tag=t[rt].tag;
t[rt].tag=0;
}
return;
}
void build(int rt,int l,int r){
t[rt].l=l;
t[rt].r=r;
if(l==r){
t[rt].sum=(long long)1<<(c[pos[l]]);//
t[rt].tag=0;
return;
}
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
return;
}
然后,我们分别来处理两种操作
首先使修改操作:
将以x结点为根的子树的颜色修改为p:
void modify(int rt,int l,int r,int p){//将dfs序从l到r这一区间的颜色改为p 注意状压
if(t[rt].l>=l&&t[rt].r<=r){
t[rt].sum=(long long)1<<p;//
t[rt].tag=(long long)1<<p;//
return ;
}
int mid=(t[rt].l+t[rt].r)>>1;
pushdown(rt);
if(l<=mid)modify(rt<<1,l,r,p);
if(r>mid)modify(rt<<1|1,l,r,p);
pushup(rt);
return;
}
还记得我一开始的dfs过程吗?我们每访问一个点就记录它的入点时间戳,然后遍历它每一个儿子,然后记录出点时间戳,这样的话一个点的入、出点时间戳之间就包含了它的所有孩子。(这有点像笛卡尔树中序遍历就是一段区间)。
所以l和r就分别是in[x]和out[x]
然后是查询操作:
long long query(int rt,int l,int r){//查询dfs序从l到r的颜色个数
if(t[rt].l>=l&&t[rt].r<=r){
return t[rt].sum;
}
pushdown(rt);
int mid=(t[rt].l+t[rt].r)>>1;
long long rec=0;
if(l<=mid)rec|=query(rt<<1,l,r);
if(r>mid)rec|=query(rt<<1|1,l,r);
pushup(rt);
return rec;
}
但要注意的是,我们返回的是一个状压之后的数字,所以我们还要分解它,这里有两种写法:
第一种比较直观,但有一些慢
while(res) {
s += res&1;
res >>= 1;
}
实际思想就是按位与记录答案
还有一种是用lowbit,没学过树状数组的话可以先去看一下树状数组。
返回某一个数二进制下第一个1所代表的值,注意:是值,不是位置
long long lowbit(long long x){
return x&(-x);
}
for(long long j=num;j>0;j-=lowbit(j))ans++;
然后接结束了!
下附AC代码:
#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
int read(){
char s;
int x=0,f=1;
s=getchar();
while(s<'0'||s>'9'){
if(s=='-')f=-1;
s=getchar();
}
while(s>='0'&&s<='9'){
x*=10;
x+=s-'0';
s=getchar();
}
return x*f;
}
struct tree{
long long sum;
long long tag;//记录要修改为哪个颜色
int l,r;
}t[1600040];
int c[400040];
int tot=0;
int pos[400040];//时间戳所对应的点
int in[400040];//入点时间戳
int out[400040];//出点时间戳
struct edge{
int to,next;
}e[800080];
int eid=0;
int head[400040];
void insert(int u,int v){
eid++;
e[eid].to=v;
e[eid].next=head[u];
head[u]=eid;
}
void dfs(int x,int fa){//用dfs序将树形结构转为线性结构
tot++;
in[x]=tot;
pos[tot]=x;
for(int i=head[x];i+1;i=e[i].next){
int k=e[i].to;
if(k==fa)continue;
dfs(k,x);
}
out[x]=tot;//其实我们out记录的是该子树中dfs序最大结点的dfs序,所以tot不加一
return;
}
void pushup(int rt){
t[rt].sum=t[rt<<1].sum|t[rt<<1|1].sum;
return;
}
void pushdown(int rt){
if(t[rt].tag!=0){//或运算不像区间和,这里0是影响答案的
t[rt<<1].sum=t[rt].tag;
t[rt<<1].tag=t[rt].tag;
t[rt<<1|1].sum=t[rt].tag;
t[rt<<1|1].tag=t[rt].tag;
t[rt].tag=0;
}
return;
}
void build(int rt,int l,int r){
t[rt].l=l;
t[rt].r=r;
if(l==r){
t[rt].sum=(long long)1<<(c[pos[l]]);//
t[rt].tag=0;
return;
}
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
return;
}
void modify(int rt,int l,int r,int p){//将dfs序从l到r这一区间的颜色改为p 注意状压
if(t[rt].l>=l&&t[rt].r<=r){
t[rt].sum=(long long)1<<p;//
t[rt].tag=(long long)1<<p;//
return ;
}
int mid=(t[rt].l+t[rt].r)>>1;
pushdown(rt);
if(l<=mid)modify(rt<<1,l,r,p);
if(r>mid)modify(rt<<1|1,l,r,p);
pushup(rt);
return;
}
long long query(int rt,int l,int r){//查询dfs序从l到r的颜色个数
if(t[rt].l>=l&&t[rt].r<=r){
return t[rt].sum;
}
pushdown(rt);
int mid=(t[rt].l+t[rt].r)>>1;
long long rec=0;
if(l<=mid)rec|=query(rt<<1,l,r);
if(r>mid)rec|=query(rt<<1|1,l,r);
pushup(rt);
return rec;
}
int n,m;
long long lowbit(long long x){
return x&(-x);
}
int main(){
memset(head,-1,sizeof(head));
n=read();
m=read();
for(int i=1;i<=n;i++){
c[i]=read();
}
for(int i=1;i<n;i++){
int x,y;
x=read();
y=read();
insert(x,y);
insert(y,x);
}
dfs(1,0);
build(1,1,n);
for(int i=1;i<=m;i++){
int a;
a=read();
if(a==1){
int x,y;
x=read();
y=read();
modify(1,in[x],out[x],y);
//cout<<query(1,in[x],in[x])<<endl;
}
else{
int x;
x=read();
long long num=query(1,in[x],out[x]);
int ans=0;
//cout<<num<<endl;
for(long long j=num;j>0;j-=lowbit(j))ans++;
cout<<ans<<endl;
}
}
return 0;
}