思路:单点的树链剖分题目.大致的思路与边问题的树链剖分大同小异.
只是在处理相邻的颜色的计算的时候需要仔细.wa了几发.
这边就写一下我处理的思路以及记得起来的wa点.
理解树链剖分之后,你会明白,树链剖分是按照路径两边往中间缩.
所以我用cu,cv标记两端的颜色.
当需要从u开始搜的时候,我就比对一下该端的上一个边缘颜色是否相同.假如相同的话必然减一.
好了,这个可以解决大部分问题.
剩下的还有一个问题.
就是当其中一个端点已经缩到根上的时候.并且也就是挑出循环之后我们所要做的处理.
这时候我们做的还是一样,这个时候因为只剩这么一个区间了(这个区间可能是1个点也可能是多个点)
所以我们要做的就是将这个区间的两端的边缘颜色与cu,cv对应比较,假如相同,则答案要减一.
其实上面的处理也是要与两端对应比较,但应该这个时候另外一端还没有缩过来,对应颜色为-1,肯定不同,所以就不用比了.
现在是我的WA点.我wa的数据为:
6 100
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
C 4 3 1
C 2 2 2
Q 4 1
#define _CRT_SECURE_NO_WARNINGS
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define MAX 100010
#define ls rt<<1
#define rs ls|1
#define m (l+r)>>1
int sum[MAX << 2], lb[MAX << 2], rb[MAX << 2], col[MAX << 2];
int son[MAX], size[MAX], pre[MAX], top[MAX], deep[MAX];
int cnt, pos;
int head[MAX], posx[MAX],posx2[MAX],co[MAX], e[MAX][2];
struct edge{
int v;
int next;
}edg[MAX << 1];
void addedge(int u, int v)
{
edg[cnt].v = v;
edg[cnt].next = head[u];
head[u] = cnt++;
}
void dfs(int u, int p, int d)
{
deep[u] = d;
size[u] = 1;
pre[u] = p;
for (int i = head[u]; i != -1; i = edg[i].next)
{
int v = edg[i].v;
if (v != p)
{
dfs(v, u, d + 1);
size[u] += size[v];
if (son[u] == -1 || size[v] > size[son[u]])
son[u] = v;
}
}
}
void getpos(int u, int p)
{
top[u] = p;
posx2[pos] = u;
posx[u] = pos++;
if (son[u] == -1)
return;
getpos(son[u], p);
for (int i = head[u]; i != -1; i = edg[i].next)
{
int v = edg[i].v;
if (v != son[u] && v != pre[u])
getpos(v, v);
}
}
//线段树
void uprt(int rt)
{
sum[rt] = sum[ls] + sum[rs];
sum[rt] -= (rb[ls] == lb[rs]);
lb[rt] = lb[ls];
rb[rt] = rb[rs];
}
void ups(int rt)
{
if (col[rt] + 1)
{
sum[ls] = sum[rs] = 1;
lb[ls] = rb[ls] = col[rt];
rb[rs] = lb[rs] = col[rt];
col[ls] = col[rt];
col[rs] = col[rt];
col[rt] = -1;
}
}
void build(int l, int r, int rt)
{
col[rt] = -1;
lb[rt] = -1;
rb[rt] = -1;
if (l == r)
{
lb[rt] = rb[rt] = co[posx2[l]];
sum[rt] = 1;
return;
}
int mid = m;
build(l, mid, ls);
build(mid + 1, r, rs);
uprt(rt);
}
void updata(int L, int R, int c, int l, int r, int rt)
{
if (L <= l&&r <= R)
{
col[rt] = c;
sum[rt] = 1;
lb[rt] = rb[rt] = c;
return;
}
ups(rt);
int mid = m;
if (L <= mid)
updata(L, R, c, l, mid, ls);
if (mid<R)
updata(L, R, c, mid + 1, r, rs);
uprt(rt);
}
int query(int L, int R, int l, int r, int rt)
{
if (L <= l&&r <= R)
return sum[rt];
ups(rt);
int mid = m;
int ans = 0;
if (L <= mid)
ans = query(L, R, l, mid, ls);
if (mid<R)
{
if (ans != 0)
ans -= (rb[ls] == lb[rs]);
ans += query(L, R, mid + 1, r, rs);
}
return ans;
}
int getco(int q, int l, int r, int rt)
{
if (l == r)
return lb[rt];
ups(rt);
int mid = m;
if (q <= mid)
return getco(q, l, mid, ls);
else
return getco(q, mid + 1, r, rs);
}
int solve(int u, int v)
{
int fu = top[u], fv = top[v];
int cu = -1, cv = -1;
int ans = 0;
while (fu != fv)
{
if (deep[fu] < deep[fv])
{
swap(fu, fv);
swap(u, v);
swap(cu, cv);
}
if (cu == getco(posx[u],1,pos-1,1))
ans--;
cu = getco(posx[fu], 1, pos - 1, 1);//这个区间的结束颜色
ans += query(posx[fu], posx[u], 1, pos - 1, 1);
u = pre[fu];
fu = top[u];
}
if (u == v)
{
ans++;
int ll = getco(posx[u], 1, pos - 1, 1);
if (cu == ll)
ans--;
if (cv == ll)
ans--;
return ans;
}
if (deep[u] > deep[v])
{
swap(u, v);
swap(cu, cv);
}
if (cu == getco(posx[u], 1, pos - 1, 1))
ans--;
cu = getco(posx[v], 1, pos - 1, 1);
ans += query(posx[u], posx[v], 1, pos - 1, 1);
if (cu == cv)
ans--;
return ans;
}
void solve(int u, int v, int c)
{
int fu = top[u], fv = top[v];
while (fu != fv)
{
if (deep[fu] < deep[fv])
{
swap(fu, fv);
swap(u, v);
}
updata(posx[fu], posx[u], c, 1, pos - 1, 1);
u = pre[fu];
fu = top[u];
}
if (u == v)
{
updata(posx[u], posx[u], c, 1, pos - 1, 1);
return;
}
if (deep[u]>deep[v])swap(u, v);
updata(posx[u], posx[v], c, 1, pos - 1, 1);
}
void init()
{
cnt = 0;
pos = 1;
memset(head, -1, sizeof(head));
memset(son, -1, sizeof(son));
}
int main()
{
int n, k;
while (~scanf("%d%d", &n, &k))
{
init();
for (int i = 1; i <= n; i++)
scanf("%d", &co[i]);
for (int i = 0; i < n - 1; i++)
{
scanf("%d%d", &e[i][0], &e[i][1]);
addedge(e[i][0], e[i][1]);
addedge(e[i][1], e[i][0]);
}
dfs(1, 1, 0);
getpos(1, 1);
build(1, pos - 1, 1);
char str[10];
int a, b, c;
for (int i = 0; i < k; i++)
{
scanf("%s", str);
if (str[0] == 'Q')
{
scanf("%d%d", &a, &b);
printf("%d\n", solve(a, b));
}
else
{
scanf("%d%d%d", &a, &b, &c);
solve(a, b, c);
}
}
}
}