根据bilibili上一个up主的视频写的代码
以此文记录备忘
先定义Node类,用于构造需要剪枝的二叉树
class Node:
"""
:param value 该节点的值,默认为0
:param is_max 该节点是否在Max层
child 是该节点的子节点,是一个Node类的数组
"""
def __init__(self, value=0, is_max=True):
self.value = value
self.is_max = is_max
self.child = None
def setChildWithValue(self, childs):
temp = []
is_Max = not self.is_max
for c in childs:
temp.append(Node(value=c, is_max=is_Max))
self.child = temp
def setChildWithNode(self, childs):
if None is self.child:
self.child = childs
return
for c in childs:
self.child.append(c)
定义一个初始函数用于构造待剪枝的树(也就是把up主的那张图复现出来),可以换几组数据多试试,如果有问题请告知(最好带上数据),毕竟是个人的练习:
def init():
node1 = Node(is_max=False)
node1.setChildWithValue([60, 63])
node2 = Node(is_max=False)
node2.setChildWithValue([15, 58])
node3 = Node(is_max=True)
node3.setChildWithNode([node1, node2])
node1 = Node(is_max=False)
node1.setChildWithValue([81, 74])
node2 = Node(is_max=False)
node2.setChildWithValue([88, 15, 27])
node4 = Node(is_max=True)
node4.setChildWithNode([node1, node2])
node5 = Node(is_max=False)
node5.setChildWithNode([node3, node4])
node1 = Node(is_max=False)
node1.setChildWithValue([20, 92])
node2 = Node(is_max=False)
node2.setChildWithValue([9, 62])
node3 = Node(is_max=True)
node3.setChildWithNode([node1, node2])
node1 = Node(is_max=False)
node1.setChildWithValue([82, 92])
node2 = Node(is_max=False)
node2.setChildWithValue([54, 17])
node4 = Node(is_max=True)
node4.setChildWithNode([node1, node2])
node6 = Node(is_max=False)
node6.setChildWithNode([node3, node4])
head = Node(is_max=True)
head.setChildWithNode([node5, node6])
return head
Minmax :
def mini_max(node):
if node.child is None:
return node.value
if not node.is_max:
best_value = float('inf')
for c in node.child:
best_value = min(best_value, mini_max(c))
else:
best_value = -float('inf')
for c in node.child:
best_value = max(best_value, mini_max(c))
return best_value
Minmax的简化版本negativeMax:
def negative_max(node):
if node.child is None:
return -node.value
best_value = -float('inf')
for c in node.child:
best_value = -max(best_value, negative_max(c))
return best_value
α − β \alpha-\beta α−β剪枝
def alpha_beta(node, alpha, beta):
# alpha表示己方,要提高到最大利益,beta表示敌方,要降到最小利益
# alpha大于beta的时候就可以开始剪枝了,因为己方收益已经可以保证大于敌方收益了
# min层修改beta(最小化敌方收益),max层修改alpha(最大化己方收益)
if node.child is None:
return node.value
if not node.is_max:
# 该层为min层,要最小化敌方收益,所以best_value取越小越好(初始一个无穷大)
best_value = float('inf')
for c in node.child:
value = alpha_beta(c, alpha, beta)
best_value = min(best_value, value)
beta = min(beta, best_value)
if alpha >= beta:
break
else:
# 该层为max层,要最大化己方收益,所以best_value取越大越好(初始一个无穷小)
best_value = -float('inf')
for c in node.child:
value = alpha_beta(c, alpha, beta)
best_value = max(best_value, value)
alpha = max(alpha, best_value)
if alpha >= beta:
break
return best_value
测试:
if __name__ == '__main__':
head = init()
# print(mini_max(head))
# print(-negative_max(head))
print(alpha_beta(head, -float('inf'), float('inf')))