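// 比较恶心的树链剖分的题,主要是在查询的时候注意链和链之间合并的时候去掉多余的分段。在线段树里面记录线段内有多少个分段,还有线段左右端点的颜色,合并的时候如果左边线段右端的颜色和右边线段左端的颜色相同的话总段数减一。
// (English) A fiddly heavy-light decomposition problem. The main point is that when query results from different chains are joined, the extra segment counted at each junction must be removed. Each segment-tree node stores how many color segments its interval contains, plus the colors at its left and right ends; when two intervals are merged, if the right-end color of the left interval equals the left-end color of the right interval, the total segment count is reduced by one.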
/**************************************************************
Problem: 2243
User: shllhsno1
Language: C++
Result: Accepted
Time:5856 ms
Memory:22336 kb
****************************************************************/
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<vector>
// Capacity of every per-vertex array.
#define N 100500
// Argument lists for recursing into the left / right child of segment-tree node rt
// covering [l, r]. They expand in the caller's scope, so l, r and rt must be in scope there,
// and the names ls, rs, mid and N must never be used as ordinary identifiers in this file.
#define ls l,mid,rt<<1
#define rs mid+1,r,rt<<1|1
// Midpoint of [l, r]. Macros expand at their use sites, so defining this after ls/rs is fine.
#define mid ((l+r)>>1)
using namespace std;
// a: initial color of each vertex (read in main); mp: vertex -> DFS position; rmp: DFS position -> vertex.
int a[N],mp[N],rmp[N];
// DFS position counter used by dfs2; afterwards it equals the number of positions assigned.
int idx;
// Per-vertex HLD data: depth, heavy child (0 = none), parent, top of heavy chain, subtree size.
int dep[N],son[N],fa[N],top[N],sz[N];
// Tree adjacency lists; main stores every edge in both directions.
vector<int>e[N];
void dfs1(int u,int d)
{
dep[u] = d;
sz[u] = 1;
int Min = 0;
for(int i = 0; i<e[u].size(); i++)
{
int v = e[u][i];
if(v == fa[u])continue;
fa[v] = u;
dfs1(v,d+1);
sz[u]+=sz[v];
if(sz[v]>Min)
{
Min = sz[v];
son[u] = v;
}
}
}
void dfs2(int u,int tp)
{
mp[u] = ++idx;
rmp[idx] = u;
top[u] = tp;
if(son[u])dfs2(son[u],tp);
for(int i = 0; i<e[u].size(); i++)
{
int v = e[u][i];
if(v == fa[u]||v == son[u])continue;
dfs2(v,v);
}
}
// Segment-tree node over DFS positions 1..idx; the root is tree[1] and
// the children of node rt are rt<<1 and rt<<1|1.
//   l, r : interval bounds, stored by build (not read back by the visible code)
//   lc   : color at the left end of the interval
//   rc   : color at the right end of the interval
//   c    : number of maximal same-color segments in the interval
//   tp   : pending color-assignment tag consumed by pushdown; -1 means nothing pending.
//          build/pushup also leave a color here when the interval is a single segment,
//          which is harmless to push down because the children already have that color.
struct node
{
int l,r,lc,rc,c,tp;
} tree[4*N];
void pushup(int rt)
{
tree[rt].lc = tree[rt<<1].lc;
tree[rt].rc = tree[rt<<1|1].rc;
tree[rt].c = tree[rt<<1].c+tree[rt<<1|1].c;
if(tree[rt<<1].rc == tree[rt<<1|1].lc)
tree[rt].c--;
if(tree[rt].c == 1)tree[rt].tp = tree[rt<<1].tp;
else tree[rt].tp = -1;
}
/// Applies node rt's pending color tag (if any) to both children, making
/// each of them one segment of that color, then clears rt's tag.
void pushdown(int rt)
{
    const int color = tree[rt].tp;
    if (color == -1)
        return;   // nothing pending
    for (int child = rt << 1; child <= (rt << 1 | 1); ++child)
    {
        tree[child].tp = color;
        tree[child].lc = color;
        tree[child].rc = color;
        tree[child].c = 1;
    }
    tree[rt].tp = -1;
}
void build(int l,int r,int rt)
{
tree[rt].l = l;
tree[rt].r = r;
if(l == r)
{
tree[rt].lc = tree[rt].rc = tree[rt].tp = a[rmp[l]];
tree[rt].c = 1;
return;
}
build(ls);
build(rs);
pushup(rt);
}
void update(int l,int r,int rt,int L,int R,int tp)
{
if(L<=l&&R>=r)
{
tree[rt].tp = tree[rt].lc = tree[rt].rc = tp;
tree[rt].c = 1;
return;
}
pushdown(rt);
if(L<=mid)update(ls,L,R,tp);
if(R>mid)update(rs,L,R,tp);
pushup(rt);
}
/// Queries DFS positions [L, R] beneath node rt, which covers [l, r].
/// @return (number of color segments, (color at the L end, color at the R end))
pair<int,pair<int,int> > query(int l,int r,int rt,int L,int R)
{
    typedef pair<int,pair<int,int> > Result;
    if (L <= l && r <= R)
        return Result(tree[rt].c, make_pair(tree[rt].lc, tree[rt].rc));
    pushdown(rt);
    if (R <= mid)
        return query(ls, L, R);
    if (L > mid)
        return query(rs, L, R);
    // The range straddles the midpoint: combine both halves.
    const Result lhs = query(ls, L, R);
    const Result rhs = query(rs, L, R);
    Result merged(lhs.first + rhs.first,
                  make_pair(lhs.second.first, rhs.second.second));
    if (lhs.second.second == rhs.second.first)
        --merged.first;   // same color meets at the boundary: one segment, not two
    return merged;
}
void change(int a,int b,int tp)
{
int ta = top[a];
int tb = top[b];
while(ta!=tb)
{
if(dep[ta]<dep[tb])
{
swap(ta,tb);
swap(a,b);
}
update(1,idx,1,mp[ta],mp[a],tp);
a = fa[ta];
ta = top[a];
}
if(dep[a]>dep[b])swap(a,b);
update(1,idx,1,mp[a],mp[b],tp);
}
int get(int a,int b)
{
int ta = top[a];
int tb = top[b];
int la = -1,lb = -1,ret = 0;
pair<int,pair<int,int> >t;
while(ta!=tb)
{
if(dep[ta]<dep[tb])
{
swap(ta,tb);
swap(a,b);
swap(la,lb);
}
t = query(1,idx,1,mp[ta],mp[a]);
ret+=t.first;
if(t.second.second == la)ret--;
la = t.second.first;
a = fa[ta];
ta = top[a];
}
if(dep[a]>dep[b])swap(a,b),swap(la,lb);
t = query(1,idx,1,mp[a],mp[b]);
ret+=t.first;
if(t.second.first == la)ret--;
if(t.second.second == lb)ret--;
return ret;
}
/// Processes test cases until the input runs out.
/// Each case gives n vertices and m operations, then the n initial colors,
/// then n-1 undirected edges, then m operations:
///   op starting with 'Q', followed by x y : print the segment count on path x..y
///   any other op, followed by x y z       : paint path x..y with color z
int main()
{
    int n,m;
    // Require both numbers. Comparing against EOF alone loops forever on a
    // malformed token, because scanf then returns 0 without consuming it.
    while(scanf("%d%d",&n,&m)==2)
    {
        idx = 0;
        for(int i = 1; i<=n; i++)
        {
            scanf("%d",&a[i]);
            e[i].clear();
        }
        for(int i = 1; i<n; i++)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            e[x].push_back(y);
            e[y].push_back(x);
        }
        // Root at vertex 1; -1 is a parent value that no vertex id can equal.
        fa[1] = -1;
        memset(son,0,sizeof(son));
        dfs1(1,0);
        dfs2(1,1);
        build(1,idx,1);
        for(int i = 1; i<=m; i++)
        {
            // Width-limited read: an unbounded %s into a 2-byte buffer
            // overflows on any token longer than one character.
            char op[16];
            int x,y,z;
            if(scanf("%15s",op)!=1)break;
            if(op[0] == 'Q')
            {
                scanf("%d%d",&x,&y);
                printf("%d\n",get(x,y));
            }
            else
            {
                scanf("%d%d%d",&x,&y,&z);
                change(x,y,z);
            }
        }
    }
    return 0;
}