class SegmentTree:
    """Sum segment tree with lazy propagation over a fixed-length sequence.

    Supports point assignment, range increment and range-sum queries.
    Nodes are stored 1-indexed in ``self.tree``: node ``k`` has children
    ``2 * k`` and ``2 * k + 1``.  ``self.lazy[k]`` holds an increment that
    still has to be added to every element covered by node ``k``.

    All index ranges are 0-based and inclusive on both ends.
    """

    def __init__(self, arr):
        """Build the tree from the values in *arr*.

        An empty *arr* yields an empty tree whose range sums are all 0.
        """
        self.n = len(arr)
        # Allocate at least one block of nodes so an empty tree still has a
        # valid (all-zero) root instead of zero-length arrays.
        size = 4 * max(self.n, 1)
        self.tree = [0] * size  # segment sums, 1-indexed by node
        self.lazy = [0] * size  # pending per-element increments, by node
        if self.n:
            # Guard: build() on the empty range [0, -1] would never terminate.
            self.build(arr, 1, 0, self.n - 1)

    def _push_down(self, node, start, end):
        """Apply the pending increment of *node* (covering [start, end]).

        The node's sum is brought up to date and the increment is handed to
        its two children, where it stays pending.
        """
        pending = self.lazy[node]
        if pending == 0:
            return
        self.tree[node] += (end - start + 1) * pending
        if start != end:
            self.lazy[2 * node] += pending
            self.lazy[2 * node + 1] += pending
        self.lazy[node] = 0

    def build(self, arr, node, start, end):
        """Recursively fill *node* with the sum of ``arr[start..end]``.

        Requires ``start <= end``.
        """
        if start == end:
            self.tree[node] = arr[start]
            return
        mid = (start + end) // 2
        self.build(arr, 2 * node, start, mid)
        self.build(arr, 2 * node + 1, mid + 1, end)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def update(self, node, start, end, idx, val):
        """Assign *val* to position *idx* inside *node*'s range [start, end].

        *idx* is expected to lie within [start, end]; use
        ``update_value_at_index`` for a bounds-checked entry point.
        """
        # Flush increments left behind by earlier range updates; otherwise
        # they would later be added on top of the freshly assigned value.
        self._push_down(node, start, end)
        if start == end:
            self.tree[node] = val
            return
        mid = (start + end) // 2
        left_child, right_child = 2 * node, 2 * node + 1
        if start <= idx <= mid:
            self.update(left_child, start, mid, idx, val)
            # The untouched sibling may still carry a pending increment; its
            # sum must be current before the parent is recomputed.
            self._push_down(right_child, mid + 1, end)
        else:
            self._push_down(left_child, start, mid)
            self.update(right_child, mid + 1, end, idx, val)
        self.tree[node] = self.tree[left_child] + self.tree[right_child]

    def query(self, node, start, end, left, right):
        """Return the sum of the part of [left, right] covered by *node*."""
        if start > right or end < left:
            return 0
        self._push_down(node, start, end)
        if left <= start and end <= right:
            return self.tree[node]
        mid = (start + end) // 2
        return (
            self.query(2 * node, start, mid, left, right)
            + self.query(2 * node + 1, mid + 1, end, left, right)
        )

    def range_update(self, node, start, end, left, right, val):
        """Add *val* to every position of [left, right] covered by *node*."""
        # Flush first, even when the ranges do not overlap: the caller
        # recomputes its own sum from this node's value right afterwards.
        self._push_down(node, start, end)
        if start > right or end < left:
            return
        if left <= start and end <= right:
            # Record the increment on this node and apply it immediately;
            # the children receive it lazily.
            self.lazy[node] += val
            self._push_down(node, start, end)
            return
        mid = (start + end) // 2
        self.range_update(2 * node, start, mid, left, right, val)
        self.range_update(2 * node + 1, mid + 1, end, left, right, val)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def update_value_at_index(self, idx, val):
        """Set the element at *idx* to *val*.

        Raises:
            IndexError: if *idx* is outside ``[0, n - 1]``.
        """
        if not 0 <= idx < self.n:
            raise IndexError(
                f"index {idx} out of range for SegmentTree of size {self.n}"
            )
        self.update(1, 0, self.n - 1, idx, val)

    def query_range_sum(self, left, right):
        """Return the sum of the elements at positions left..right.

        Positions outside the stored sequence contribute nothing.
        """
        if self.n == 0:
            return 0
        return self.query(1, 0, self.n - 1, left, right)

    def range_update_values(self, left, right, val):
        """Add *val* to every element at positions left..right.

        Positions outside the stored sequence are ignored.
        """
        if self.n == 0:
            return
        self.range_update(1, 0, self.n - 1, left, right, val)
# Usage example:
arr = [1, 3, 5, 7, 9, 11]
st = SegmentTree(arr)
print(st.query_range_sum(1, 4)) # Prints 24 (3 + 5 + 7 + 9)
st.update_value_at_index(2, 6) # Replace the third element (5) with 6
print(st.query_range_sum(1, 4)) # Prints 25 (3 + 6 + 7 + 9)
st.range_update_values(0, 2, 2) # Add 2 to every element in the range [0, 2]
print(st.query_range_sum(0, 5)) # Prints 43 (3 + 5 + 8 + 7 + 9 + 11)