问你一段路上,不同数字的连续段有多少
就是连续的相同的数字要合并,这就是典型的区间合并问题,然后加上区间赋值
然后每一段查询的时候,区间右端点都要和原来的数字比一下是否相同
最后两个点都在重链上的时候,最后一个区间,不仅右端点要比较,左端点也要比较是否相同
代码:
#include <map>
#include <set>
#include <stack>
#include <queue>
#include <cmath>
#include <string>
#include <vector>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <sstream>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#pragma comment(linker, "/STACK:102400000,102400000")
using namespace std;
#define MAX 40005
#define MAXN 6005
#define maxnode 10
#define sigma_size 30
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define lrt rt<<1
#define rrt rt<<1|1
#define middle int m=(r+l)>>1
#define LL long long
#define ull unsigned long long
#define mem(x,v) memset(x,v,sizeof(x))
#define lowbit(x) (x&-x)
#define pii pair<int,int>
#define bits(a) __builtin_popcount(a)
#define mk make_pair
#define limit 10000
//const int prime = 999983;
const int INF = 0x3f3f3f3f;
const LL INFF = 0x3f3f;
const double pi = acos(-1.0);
//const double inf = 1e18;
const double eps = 1e-8;
const LL mod = 1e9+7;
const ull mx = 133333331;
/*****************************************************/
inline void RI(int &x) {
char c;
while((c=getchar())<'0' || c>'9');
x=c-'0';
while((c=getchar())>='0' && c<='9') x=(x<<3)+(x<<1)+c-'0';
}
/*****************************************************/
struct Edge{
int u,v,next,c;
}edge[MAX*2];
int head[MAX];
int tot;//size是子树节点个数,son记录重链是哪个子节点
int top[MAX],son[MAX],size[MAX],dep[MAX];//top记录重链上的祖先
int tid[MAX],fa[MAX];//tid为先重链后其他边的新标号
int id[MAX];//新标号的点原来的标号
int label;
int num[MAX];
int n;
struct Node{
int num,numl,numr;
}p[MAX<<2];
int col[MAX<<2];
void init(){
mem(head,-1);
mem(son,-1);
label=0;
tot=0;
}
void add_edge(int a,int b,int c){
edge[tot]=(Edge){a,b,head[a],c};
head[a]=tot++;
}
//找重边
void dfs1(int u,int f,int d){
dep[u]=d;
fa[u]=f;
size[u]=1;
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(v==f) continue;
dfs1(v,u,d+1);
size[u]+=size[v];
if(son[u]==-1||size[v]>size[son[u]]) son[u]=v;
}
}
//连接重链
void dfs2(int u,int ance){
top[u]=ance;
tid[u]=++label;
id[tid[u]]=u;
if(son[u]==-1) return;
dfs2(son[u],ance);
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
}
}
void pushup(int rt){
p[rt].num=p[lrt].num+p[rrt].num;
//cout<<p[rt].num<<endl;
p[rt].numl=p[lrt].numl;
p[rt].numr=p[rrt].numr;
if(p[lrt].numr==p[rrt].numl) p[rt].num--;
}
void pushdown(int rt){
if(col[rt]){
col[lrt]=col[rrt]=1;
col[rt]=0;
p[lrt].num=p[rrt].num=1;
p[lrt].numl=p[lrt].numr=p[rt].numl;
p[rrt].numl=p[rrt].numr=p[rt].numr;
}
}
void build(int l,int r,int rt){
col[rt]=0;
if(l==r){
p[rt].num=1;
p[rt].numl=p[rt].numr=num[id[l]];
//cout<<num[id[l]]<<endl;
return;
}
middle;
build(lson);
build(rson);
pushup(rt);
}
void update(int l,int r,int rt,int L,int R,int d){
if(L<=l&&r<=R){
col[rt]=1;
p[rt].num=1;
p[rt].numl=p[rt].numr=d;
return;
}
middle;
pushdown(rt);
if(L<=m) update(lson,L,R,d);
if(R>m) update(rson,L,R,d);
pushup(rt);
}
Node query(int l,int r,int rt,int L,int R){
if(L<=l&&r<=R) return p[rt];
middle;
pushdown(rt);
if(R<=m) return query(lson,L,R);
else if(L>m) return query(rson,L,R);
else{
Node p1=query(lson,L,R);
Node p2=query(rson,L,R);
Node ans=(Node){p1.num+p2.num,p1.numl,p2.numr};
if(p1.numr==p2.numl) ans.num--;
return ans;
}
}
void change(int x,int y,int z){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
update(1,n,1,tid[top[x]],tid[x],z);
x=fa[top[x]];
}
if(x==y) return;
if(dep[x]<dep[y]) swap(x,y);
update(1,n,1,tid[son[y]],tid[x],z);
}
int que(int x,int y){
int ans=0;
int cx=-1,cy=-1;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]){
swap(x,y);
swap(cx,cy);
}
Node tmp=query(1,n,1,tid[top[x]],tid[x]);
ans+=tmp.num;
//cout<<x<<" "<<tmp.num<<endl;
//cout<<tmp.num<<endl;
if(tmp.numr==cx) ans--;
cx=tmp.numl;
x=fa[top[x]];
//cout<<cx<<" "<<ans<<endl;
}
if(x==y){
if(cx==cy) ans--;
return ans;
}
if(dep[x]<dep[y]){
swap(x,y);
swap(cx,cy);
}
Node tmp=query(1,n,1,tid[son[y]],tid[x]);
ans+=tmp.num;
//cout<<tmp.num<<endl;
if(cx==tmp.numr) ans--;
if(cy==tmp.numl) ans--;
return ans;
}
char s[10];
int main(){
int q;
while(cin>>n>>q){
init();
for(int i=1;i<n;i++){
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
add_edge(a,b,c);
add_edge(b,a,c);
}
dfs1(1,-1,0);
dfs2(1,1);
num[1]=0;
for(int i=0;i<tot;i+=2){
int x=edge[i].u;
int y=edge[i].v;
if(dep[x]<dep[y]) swap(x,y);
num[x]=edge[i].c;
}
build(1,n,1);
while(q--){
scanf("%s",s);
int a,b;
scanf("%d%d",&a,&b);
if(s[0]=='Q'){
if(a==b){
printf("0\n");
continue;
}
printf("%d\n",que(a,b));
}
else if(s[0]=='C'){
int c;
scanf("%d",&c);
change(a,b,c);
}
}
}
return 0;
}