题目大意:给出一棵树,和它上面每个点的初始颜色。有两种操作,1:求从x到y一共有多少段颜色(连续相同的颜色算一个颜色段)。2:把x到y路径上都染色成z。
思路:是一棵树,求LCA路径,迅速想到树链剖分。难点是维护区间合并问题。线段树上的区间合并很常规,正常做就可以,注意一下在从一个重链上跳到另一个重链的时候的区间合并。有两种解决方案。1:在线段树中询问的时候返回一个结构体,里面存着左边颜色,右边颜色,总共的颜色段数量。然后在外面判断。这样做比较快,但是实现起来比较麻烦。2:在跳重链的时候直接在线段树中单点查询本链在最顶端的顶点的颜色,和下一条重链开始的节点颜色直接判断。实现起来比较简单,但是会增加PushDown的次数。不过毕竟是均摊logn的复杂度,放心些不会T的。
CODE:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define MAX 200010
#define LEFT (pos << 1)
#define RIGHT (pos << 1|1)
using namespace std;
struct Complex{
int l_col,r_col,cnt_col;
int flag;
}tree[MAX << 2];
int points,asks;
int src[MAX],color[MAX];
int head[MAX],total;
int next[MAX],aim[MAX];
int deep[MAX],son[MAX],father[MAX];
int pos[MAX],top[MAX],cnt;
char flag[10];
inline void Add(int x,int y);
int PreDFS(int x,int last);
void DFS(int x,int last,int root);
void BuildTree(int l,int r,int pos);
inline void PushUp(int pos);
inline void PushDown(int pos);
inline void Modify(int x,int y,int c);
void Modify(int l,int r,int x,int y,int pos,int c);
inline int Ask(int x,int y);
int Ask(int l,int r,int x,int y,int pos);
int Ask(int l,int r,int x,int pos);
int main()
{
cin >> points >> asks;
for(int i = 1;i <= points; ++i)
scanf("%d",&src[i]);
for(int x,y,i = 1;i < points; ++i) {
scanf("%d%d",&x,&y);
Add(x,y),Add(y,x);
}
PreDFS(1,0);
DFS(1,0,1);
BuildTree(1,points,1);
for(int x,y,z,i = 1;i <= asks; ++i) {
scanf("%s",flag);
if(flag[0] == 'C') {
scanf("%d%d%d",&x,&y,&z);
Modify(x,y,z);
}
else {
scanf("%d%d",&x,&y);
printf("%d\n",Ask(x,y));
}
}
return 0;
}
inline void Add(int x,int y)
{
next[++total] = head[x];
aim[total] = y;
head[x] = total;
}
int PreDFS(int x,int last)
{
int re = 1,_max = 0;
deep[x] = deep[last] + 1;
father[x] = last;
for(int i = head[x];i;i = next[i]) {
if(aim[i] == last) continue;
int temp = PreDFS(aim[i],x);
if(temp > _max) _max = temp,son[x] = aim[i];
re += temp;
}
return re;
}
void DFS(int x,int last,int root)
{
pos[x] = ++cnt;
top[x] = root;
color[pos[x]] = src[x];
if(son[x]) DFS(son[x],x,root);
for(int i = head[x];i;i = next[i]) {
if(aim[i] == last || aim[i] == son[x]) continue;
DFS(aim[i],x,aim[i]);
}
}
void BuildTree(int l,int r,int pos)
{
tree[pos].flag = -1;
if(l == r) {
tree[pos].l_col = tree[pos].r_col = color[l];
tree[pos].cnt_col = 1;
return ;
}
int mid = (l + r) >> 1;
BuildTree(l,mid,LEFT);
BuildTree(mid + 1,r,RIGHT);
PushUp(pos);
}
inline void PushUp(int pos)
{
tree[pos].l_col = tree[LEFT].l_col;
tree[pos].r_col = tree[RIGHT].r_col;
tree[pos].cnt_col = tree[LEFT].cnt_col + tree[RIGHT].cnt_col - (tree[LEFT].r_col == tree[RIGHT].l_col);
}
inline void PushDown(int pos)
{
if(tree[pos].flag != -1) {
tree[LEFT].l_col = tree[LEFT].r_col = tree[LEFT].flag = tree[pos].flag;
tree[RIGHT].l_col = tree[RIGHT].r_col = tree[RIGHT].flag = tree[pos].flag;
tree[LEFT].cnt_col = tree[RIGHT].cnt_col = 1;
tree[pos].flag = -1;
}
}
inline void Modify(int x,int y,int c)
{
while(top[x] != top[y]) {
if(deep[top[x]] < deep[top[y]])
swap(x,y);
Modify(1,cnt,pos[top[x]],pos[x],1,c);
x = father[top[x]];
}
if(deep[x] < deep[y]) swap(x,y);
Modify(1,cnt,pos[y],pos[x],1,c);
}
void Modify(int l,int r,int x,int y,int pos,int c)
{
if(l == x && y == r) {
tree[pos].l_col = tree[pos].r_col = c;
tree[pos].cnt_col = 1;
tree[pos].flag = c;
return ;
}
PushDown(pos);
int mid = (l + r) >> 1;
if(y <= mid) Modify(l,mid,x,y,LEFT,c);
else if(x > mid) Modify(mid + 1,r,x,y,RIGHT,c);
else {
Modify(l,mid,x,mid,LEFT,c);
Modify(mid + 1,r,mid + 1,y,RIGHT,c);
}
PushUp(pos);
}
inline int Ask(int x,int y)
{
int re = 0;
while(top[x] != top[y]) {
if(deep[top[x]] < deep[top[y]])
swap(x,y);
re += Ask(1,cnt,pos[top[x]],pos[x],1);
int col1 = Ask(1,cnt,pos[top[x]],1);
int col2 = Ask(1,cnt,pos[father[top[x]]],1);
re += (col1 == col2 ? -1:0);
x = father[top[x]];
}
if(deep[x] < deep[y]) swap(x,y);
re += Ask(1,cnt,pos[y],pos[x],1);
return re;
}
int Ask(int l,int r,int x,int y,int pos)
{
if(l == x && y == r)
return tree[pos].cnt_col;
PushDown(pos);
int mid = (l + r) >> 1;
if(y <= mid) return Ask(l,mid,x,y,LEFT);
if(x > mid) return Ask(mid + 1,r,x,y,RIGHT);
int left = Ask(l,mid,x,mid,LEFT);
int right = Ask(mid + 1,r,mid + 1,y,RIGHT);
return left + right + (tree[LEFT].r_col == tree[RIGHT].l_col ? -1:0);
}
int Ask(int l,int r,int x,int pos)
{
if(l == r) return tree[pos].l_col;
PushDown(pos);
int mid = (l + r) >> 1;
if(x <= mid) return Ask(l,mid,x,LEFT);
return Ask(mid + 1,r,x,RIGHT);
}