边权化点权的处理:
向下处理,简单来说 ,就是根节点为 0 0 0,边权下放到与该节点链接的深度较大的节点。
另外,这个问题的查询也有点提示,即子路径的 x o r xor xor
所以我们需要向下异或,并且我们需要记录每个节点的异或值,然后将每个点的异或值用线段树进行维护。
按位建线段树,是因为 x o r xor xor的时候,懒标记是不能下传的,所以需要按位建树,然后统计在该位上有多少个1,多少个0, x o r xor xor的时候,只需要交换0和1的个数即可
这里采用的按位线段树的处理方式与上一篇文章略有不同,可以参考上一篇文章的建树方式:
建树:需要维护每一个下传 x o r xor xor的值
查询:直接沿链查询 x o r xor xor的异或位,需要注意一下每一个对应的是 2 n 2^n 2n,因此每次查询的时候都是要乘以 2 n 2^n 2n,然后才能相加
修改:这里有点意思,由于我们是使用下传的 x o r xor xor值,所以修改链 ( x , y ) (x,y) (x,y),就等于是在修改 x , y x,y x,y中深度较低的一个点的子树,另外由于 x o r xor xor的一个重要性质,即 a x o r b x o r b = a a {\,} {\,} xor {\,} {\,} b {\,} {\,} xor {\,} {\,} b {\,} {\,} = {\,} {\,} a axorbxorb=a,所以我们修改传入的值应该是 v a l x o r e d [ x ] val {\,} {\,}xor {\,} {\,}ed[x] valxored[x],其中 e d [ x ] ed[x] ed[x]代表的是深度较前的一个点(这里假定是x)的上边权(注意这个边权是没有经过 x o r xor xor的), v a l val val代表的是需要 x o r xor xor进去的值
#include <bits/stdc++.h>
#define inf 0x7fffffff
#define ll long long
#define int long long
//#define double long double
#define re register int
#define void inline void
#define eps 1e-8
//#define mod 1e9+7
#define ls(p) p<<1
#define rs(p) p<<1|1
#define pi acos(-1.0)
#define pb push_back
#define P pair < int , int >
#define mk make_pair
using namespace std;
const int mod=1e9+7;
const int M=1e8+5;
const int N=1e5+5;//?????????? 4e8
struct node
{
int ver,next,edge;
}e[N*2];
int tot,head[N];
int ed[N],b[N];
struct tree
{
int l,r,sum[15][2],add;
friend tree operator + (tree x,tree y)
{
tree z;
for(re i=0;i<=10;i++)
{
z.sum[i][0]=x.sum[i][0]+y.sum[i][0];
z.sum[i][1]=x.sum[i][1]+y.sum[i][1];
}
return z;
}
}t[N];
int n,m,son[N],fa[N],dep[N],dfn[N],num,top[N],w[N],sz[N],now[N];
void add(int x,int y,int z)
{
e[++tot].ver=y;
e[tot].edge=z;
e[tot].next=head[x];
head[x]=tot;
}
void addedge(int x,int y,int z)
{
add(x,y,z);add(y,x,z);
}
void dfs1(int x,int pre)
{
int maxn=-1;
dep[x]=dep[pre]+1;
fa[x]=pre;
sz[x]=1;
for(re i=head[x];i;i=e[i].next)
{
int y=e[i].ver;
int z=e[i].edge;
if(y==pre) continue;
b[y]=b[x]^z;
ed[y]=z;
dfs1(y,x);
sz[x]+=sz[y];
if(sz[y]>maxn)
{
maxn=sz[y];
son[x]=y;
}
}
}
void dfs2(int x,int pre)
{
dfn[x]=++num;
now[num]=x;
top[x]=pre;
w[num]=b[x];
if(!son[x]) return;
dfs2(son[x],pre);
for(re i=head[x];i;i=e[i].next)
{
int y=e[i].ver;
if(y==pre) continue;
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
void push(int p)
{
int lazy=t[p].add;
int l=t[p].l,r=t[p].r;
t[p]=t[ls(p)]+t[rs(p)];
t[p].add=lazy;
t[p].l=l,t[p].r=r;
}
void bulid(int p,int l,int r)
{
t[p].l=l,t[p].r=r;
if(l==r)
{
for(re i=0;i<=10;i++) if((b[now[l]]>>i)&1) t[p].sum[i][1]=1;
else t[p].sum[i][0]=1;
return;
}
int mid=(l+r)>>1;
bulid(ls(p),l,mid);bulid(rs(p),mid+1,r);
push(p);
}
void spread(int p)
{
if(t[p].add)
{
for(re i=0;i<=10;i++) if((t[p].add>>i)&1)
{
swap(t[ls(p)].sum[i][0],t[ls(p)].sum[i][1]);
swap(t[rs(p)].sum[i][0],t[rs(p)].sum[i][1]);
}
t[ls(p)].add^=t[p].add;
t[rs(p)].add^=t[p].add;
t[p].add=0;
}
}
void change(int p,int l,int r,int z)
{
if(l<=t[p].l&&t[p].r<=r)
{
for(re i=0;i<=10;i++) if((z>>i)&1) swap(t[p].sum[i][0],t[p].sum[i][1]);
t[p].add^=z;
return;
}
spread(p);
int mid=(t[p].l+t[p].r)>>1;
if(l<=mid) change(ls(p),l,r,z);
if(mid<r) change(rs(p),l,r,z);
push(p);
}
tree ask(int p,int l,int r)
{
if(l<=t[p].l&&t[p].r<=r) return t[p];
spread(p);
int mid=(t[p].l+t[p].r)>>1;
tree ans;
for(re i=0;i<=12;i++) ans.sum[i][0]=ans.sum[i][1]=0;
if(l<=mid) ans=ans+ask(ls(p),l,r);
if(mid<r) ans=ans+ask(rs(p),l,r);
return ans;
}
void mson(int x,int y,int v)
{
if(dep[x]<dep[y]) swap(x,y);
change(1,dfn[x],dfn[x]+sz[x]-1,v^ed[x]);
ed[x]=v;
}
int qchain(int x,int y)
{
tree ans;
for(re i=0;i<=12;i++) ans.sum[i][0]=ans.sum[i][1]=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans=ans+ask(1,dfn[top[x]],dfn[x]);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ans=ans+ask(1,dfn[x],dfn[y]);
int cnt=0;
for(re i=0;i<=10;i++) cnt+=(1<<i)*ans.sum[i][0]*ans.sum[i][1];
return cnt;
}
void solve()
{
cin>>n>>m;
for(re i=1;i<n;i++)
{
int x,y,z;
scanf("%lld%lld%lld",&x,&y,&z);
addedge(x,y,z);
}
dep[1]=1;dfs1(1,1);dfs2(1,1);bulid(1,1,n);
while(m--)
{
int op,x,y,z;
int ans=0;
scanf("%lld%lld%lld",&op,&x,&y);
if(op==1)
{
ans=qchain(x,y);
printf("%lld\n",ans);
}
else
{
scanf("%lld",&z);
mson(x,y,z);
}
}
}
signed main()
{
int T=1;
// cin>>T;
for(int index=1;index<=T;index++)
{
// printf("Case %d:\n",index);
solve();
// puts("");
}
return 0;
}
/*
1
6 5
0 0 0 122 499 8888
*/