给初中的讲题,顺便打下练手题,然而发现我已经是老年选手了,平时口胡太多打题会死的。。。竟然把dep打成dfn,本来是区间修改单点查询的线段树,为了常数改成单点修改区间询问,然而我打了单点询问。
题目
我们有一个树,大小为n。考虑树上的一条路径,如果一个边的两个点都在这路径上,我们称这个边属于这个路径,如果一个边有且只有一个点在这路径上,我们称这个边与这个路径相邻。现在每个边要么是黑色的要么是白色的,一开始所有边都是白色的。
我们有3个操作,将某路径反色,将与某路径相邻的所有边反色,求一个路径上黑边的总数。
1≤n,q≤105
Solution
为了照顾初中选手,我讲了树链剖分做法,很jieba的就是操作2,因为好像很难维护,但是我们要发现一个trick:我们路径附近的边中,除了重链上的边,其他都是一条链的链顶上面的边,那么这样我们就可以在询问时做到链顶时再特殊处理一下,具体的就是要用了两颗线段树,Sgt1用来维护每条边的颜色,Sgt2用来维护每个点的信息(除了他的重儿子以外的边是否都被反色),然后询问时就在链顶时查询一下链顶上面的父亲的Sgt2中的信息。
其实好像还有LCT作法。
Code
#include<iostream>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<algorithm>
#include<set>
#include<map>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
using namespace std;
typedef long long LL;
typedef double db;
int get(){
char ch;
while(ch=getchar(),(ch<'0'||ch>'9')&&ch!='-');
if (ch=='-'){
int s=0;
while(ch=getchar(),ch>='0'&&ch<='9')s=s*10+ch-'0';
return -s;
}
int s=ch-'0';
while(ch=getchar(),ch>='0'&&ch<='9')s=s*10+ch-'0';
return s;
}
const int N = 100010;
struct edge{
int x,nxt;
}e[N*2];
int h[N],tot;
struct node{
int l,r,ad,tot;
}tree1[N*2];
int tree2[N];
int n,q,tot1,tot2;
bool vis[N];
int dfn[N],k,be[N],m,ma[N],fa[N],dep[N],tp[N],s[N];
void inse(int x,int y){
e[++tot].x=y;
e[tot].nxt=h[x];
h[x]=tot;
}
void dfs1(int x){
vis[x]=1;
s[x]=1;
for(int p=h[x];p;p=e[p].nxt)
if (!vis[e[p].x]){
fa[e[p].x]=x;
dep[e[p].x]=dep[x]+1;
dfs1(e[p].x);
s[x]+=s[e[p].x];
if (!ma[x]||s[e[p].x]>s[ma[x]])ma[x]=e[p].x;
}
}
void dfs2(int x){
dfn[x]=++k;
if (!ma[x]){tp[be[x]=++m]=x;return;}
dfs2(ma[x]);
tp[be[x]=be[ma[x]]]=x;
for(int p=h[x];p;p=e[p].nxt)
if (fa[e[p].x]==x&&e[p].x!=ma[x])dfs2(e[p].x);
}
void build1(int &now,int l,int r){
now=++tot1;
if (l==r)return;
int mid=(l+r)/2;
build1(tree1[now].l,l,mid);
build1(tree1[now].r,mid+1,r);
}
int lca(int x,int y){
while(be[x]!=be[y]){
if (dfn[tp[be[x]]]<dfn[tp[be[y]]])swap(x,y);
x=fa[tp[be[x]]];
}
return dfn[x]<dfn[y]?x:y;
}
int jump(int x,int y){
while(dep[x]>dep[y]+1){
if (x==tp[be[x]])x=fa[x];
else{
if (be[x]==be[y])x=ma[y];
else x=tp[be[x]];
}
}
return x;
}
void down1(int now,int l,int r){
if (!tree1[now].ad)return;
int mid=(l+r)/2;
int s=tree1[now].l;
tree1[s].ad^=1;
tree1[s].tot=mid-l+1-tree1[s].tot;
s=tree1[now].r;
tree1[s].ad^=1;
tree1[s].tot=r-mid-tree1[s].tot;
tree1[now].ad=0;
}
void change1(int now,int l,int r,int x,int y){
if (x<=l&&r<=y){
tree1[now].ad^=1;
tree1[now].tot=r-l+1-tree1[now].tot;
return;
}
int mid=(l+r)/2;
down1(now,l,r);
if (x<=mid)change1(tree1[now].l,l,mid,x,y);
if (y>mid)change1(tree1[now].r,mid+1,r,x,y);
tree1[now].tot=tree1[tree1[now].l].tot+tree1[tree1[now].r].tot;
}
int getv1(int now,int l,int r,int x,int y){
if (x<=l&&r<=y)return tree1[now].tot;
int mid=(l+r)/2,ans=0;
down1(now,l,r);
if(x<=mid)ans=getv1(tree1[now].l,l,mid,x,y);
if(y>mid)ans=ans+getv1(tree1[now].r,mid+1,r,x,y);
return ans;
}
void change2(int x){
while(x<=n){
tree2[x]^=1;
x+=x&-x;
}
}
int getv2(int x){
int ans=0;
while(x){
ans^=tree2[x];
x-=x&-x;
}
return ans;
}
void operation1(int x){
while(x){
change1(1,1,n,dfn[tp[be[x]]],dfn[x]);
x=fa[tp[be[x]]];
}
}
void operation2(int x,int y){
if (ma[x])change1(1,1,n,dfn[ma[x]],dfn[ma[x]]);
while(dfn[x]>=dfn[y]){
if (be[x]!=be[y]){
change2(dfn[tp[be[x]]]);
change2(dfn[x]+1);
change1(1,1,n,dfn[tp[be[x]]],dfn[tp[be[x]]]);
change1(1,1,n,dfn[fa[tp[be[x]]]]+1,dfn[fa[tp[be[x]]]]+1);
x=fa[tp[be[x]]];
}
else{
change2(dfn[y]);
change2(dfn[x]+1);
x=fa[y];
}
}
}
int getans(int x){
int ans=0;
while(x){
if (x!=tp[be[x]])ans=ans+getv1(1,1,n,dfn[tp[be[x]]]+1,dfn[x]);
x=tp[be[x]];
if (x==1)ans=ans+getv1(1,1,n,1,1);
else ans=ans+(getv2(dfn[fa[x]])^getv1(1,1,n,dfn[x],dfn[x]));
x=fa[x];
}
return ans;
}
void work(){
q=get();
fo(i,1,q){
int ty=get(),x=get(),y=get();
if (ty==1){
operation1(x);
operation1(y);
}
if (ty==2){
if (x==y){
operation2(x,x);
change1(1,1,n,dfn[x],dfn[x]);
continue;
}
int z=lca(x,y);
if (dfn[x]>dfn[y])swap(x,y);
if (x==z){
operation2(y,z);
change1(1,1,n,dfn[z],dfn[z]);
continue;
}
operation2(x,z);
int u=jump(y,z);
operation2(y,u);
change1(1,1,n,dfn[z],dfn[z]);
change1(1,1,n,dfn[u],dfn[u]);
}
if (ty==3){
int z=lca(x,y);
printf("%d\n",getans(x)+getans(y)-getans(z)*2);
}
}
}
int main(){
freopen("data.in","r",stdin);
freopen("data.out","w",stdout);
n=get();
fo(i,2,n){
int x=get(),y=get();
inse(x,y);
inse(y,x);
}
dfs1(1);
dfs2(1);
int root;
build1(root,1,n);
work();
fclose(stdin);
fclose(stdout);
return 0;
}