问题描述
给一颗含有n个节点的有根树,根节点为1,编号为i的点有点权ai ( i ∈[1,n])。现在有两种操作,格式如下:
1 x y : 该操作将点x的点权改为y 。
2 x : 该操作表示查询以结点x 为根的子树内的所有点的点权的异或和。
现有长度为m 的操作序列,请对于每个第二类操作给出正确的结果 。
输入格式
输入的第一行包含两给正整数n,m,用一个空格分隔。
第二行包含 n 个整数a1,a2,...,an,相邻整数之间使用一个空格分隔。
接下来 n − 1 行,每行包含两个正整数 ui,vi ,表示结点 ui 和 v 之间有一条边。
接下来 m 行,每行包含一个操作
输出格式
输出若干行,每行对应一个查询操作的答案。
样例输入
4 4
1 2 3 4
1 2
1 3
2 4
2 1
1 1 0
2 1
2 2
样例输出
4
5
6
测评用例规模与约定
对于30%的测评用例, n,n<=1000;
对于所有测评用例, 1<=n,m<=100000, 0<=ai, y<=100000, 1<= ui,vi,x<=n。
概念引入
首先要了解两个概念,一个是子树,一个是异或操作。
1.题目中只是说是ui和vi之间存在一条边,没有明确指出哪个是父节点,哪个是子节点,这就产生了歧义,在子树划分的时候。由问题描述中的根节点为1,可以推断出,在所有的边中,编号较小的节点是父节点,编号较大的是子节点。(这里有点不妥,应该设置visit数组,根据先后进入顺序确定为父节点还是子节点,懒得改)
2.异或操作的性质。异或逻辑的关系是:当AB不同时,输出P=1;当AB相同时,输出P=0。在python中可以通过^操作符实现。异或操作的几个性质如下:
A^A=0;
A^0=A;
A^B^A=B^A^A=B^0=B;
思路讲解
1. 定义两个字典,child_d,f_d,分别记录子节点和父节点的信息;定义一个数组sum_,存储编号为i的节点的子树的异或和。
n,m = map(int,input().split())
num = list(map(int,input().split()))
num.insert(0,0)
child_d = dict()
f_d = dict()
sum_ = [i for i in num]
2.dfs搜索child_d,初始化异或和数组sum_。
def dfs(n):
if n not in child_d:
return num[n]
for i in child_d[n]:
sum_[n] ^=dfs(i)
return sum_[n]
_ = dfs(1)
3.对于输入的第一类操作,即返回sum_[i]即可。
4.对于第二类的修改操作,维护sum_是整道题的关键。
因为异或和数组储存的是子树的所有数一起的异或和,所以如果修改节点v的值,那么v和v的所有祖先节点(即父节点和父节点的父节点等等)的异或和数组sum_的值都会改变。
之前维护的记录父节点的字典就用得到了,修改一次节点的值之后,要循环更新所有祖先节点的异或和的值。
由之前提到的性质
A^B^A=B^A^A=B^0=B;
可以得到,
将节点v的值从n1改到n2,节点的异或和s只需要与n1,n2一起进行异或操作即可完成更新。
如: 当前节点的权值为4,异或和为
s = 4^3^2^1,
进行第二类操作将节点的值改为3
修改后的节点的异或和为
s2 = 3^3^2^1
将s与n1,n2一起进行异或操作,得
s3 = s^4^3 = 4^3^2^1^4^3 = 4^4^3^3^2^1 = 3^3^2^1 = s2
for i in range(m):
s = list(map(int,input().split()))
if s[0]==1:
temp =s[1]
val = s[2]^num[temp]
num[temp] = s[2]
sum_[temp]^= val
while temp in f_d:
temp = f_d[temp]
sum_[temp]^= val
else:
s2.append(sum_[s[1]])
完整代码
# -*- coding: utf-8 -*-
"""
Created on Wed Nov 1 11:49:51 2023
@author: 犹豫
"""
n,m = map(int,input().split())
num = list(map(int,input().split()))
num.insert(0,0)
child_d = dict()
f_d = dict()
sum_ = [i for i in num]
for i in range(n-1):
v,u = map(int,input().split())
if v>u:
v,u =u,v
if v in child_d:
child_d[v].append(u)
else:
child_d[v]=[u]
f_d[u]=v
def dfs(n):
if n not in child_d:
return num[n]
for i in child_d[n]:
sum_[n] ^=dfs(i)
return sum_[n]
_ = dfs(1)
s2=[]
for i in range(m):
s = list(map(int,input().split()))
if s[0]==1:
temp =s[1]
val = s[2]^num[temp]
num[temp] = s[2]
sum_[temp]^= val
while temp in f_d:
temp = f_d[temp]
sum_[temp]^= val
else:
s2.append(sum_[s[1]])
for i in s2:
print(i)