class TreeNode:
    """One node of a binary search tree, annotated with an AVL balance factor.

    Attributes:
        key: the ordering key.
        payload: the value stored under ``key``.
        leftChild, rightChild, parent: neighbouring nodes, or ``None``.
        balanceFactor: height(left subtree) - height(right subtree).
    """

    def __init__(self, key, val, left=None, right=None, parent=None):
        self.key = key
        self.payload = val
        self.leftChild = left
        self.rightChild = right
        self.parent = parent
        # A freshly created node is a leaf, so both subtree heights are equal.
        self.balanceFactor = 0

    @property
    def balanceFator(self):
        """Misspelled alias of ``balanceFactor``, kept for older callers."""
        return self.balanceFactor

    @balanceFator.setter
    def balanceFator(self, value):
        self.balanceFactor = value

    def hasLeftChild(self):
        """Return the left child (truthy) or ``None``."""
        return self.leftChild

    def hasRightChild(self):
        """Return the right child (truthy) or ``None``."""
        return self.rightChild

    def isLeftChild(self):
        """Return a truthy value when this node is its parent's left child."""
        return self.parent and self.parent.leftChild == self

    def isRightChild(self):
        """Return a truthy value when this node is its parent's right child."""
        return self.parent and self.parent.rightChild == self

    def isRoot(self):
        """Return True when this node has no parent."""
        return not self.parent

    def isLeaf(self):
        """Return True when this node has no children."""
        return not (self.rightChild or self.leftChild)

    def hasAnyChild(self):
        """Return a child (truthy) when at least one exists, else ``None``."""
        return self.rightChild or self.leftChild

    def hasBothChildren(self):
        """Return a truthy value only when both children exist."""
        return self.rightChild and self.leftChild

    def replaceNodeData(self, key, value, lc, rc):
        """Overwrite key, payload and both children, re-parenting the children.

        The balance factor is left untouched.
        """
        self.key = key
        self.payload = value
        self.leftChild = lc
        self.rightChild = rc
        if self.hasLeftChild():
            self.leftChild.parent = self
        if self.hasRightChild():
            self.rightChild.parent = self

    def __iter__(self):
        """Yield the keys of this subtree in ascending (in-order) order."""
        if self.hasLeftChild():
            for elem in self.leftChild:
                yield elem
        yield self.key
        if self.hasRightChild():
            for elem in self.rightChild:
                yield elem

    def findSuccessor(self):
        """Return the node holding the next larger key, or ``None`` if none.

        With a right subtree the successor is that subtree's minimum.
        Otherwise it is the first ancestor reached from its left side; the
        walk does not modify the tree.
        """
        if self.hasRightChild():
            return self.rightChild.findMin()
        node = self
        # Climb while we are a right child: those ancestors are all smaller.
        while node.parent is not None and node.parent.rightChild is node:
            node = node.parent
        return node.parent

    def findMin(self):
        """Return the left-most (smallest-key) node of this subtree."""
        current = self
        while current.hasLeftChild():
            current = current.leftChild
        return current

    def spliceOut(self):
        """Unlink this node from its parent, promoting its only child if any.

        Intended for a node with at most one child and an existing parent
        (the left child is promoted when present, otherwise the right one).
        Balance factors are not adjusted here.
        """
        child = self.leftChild or self.rightChild
        if self.isLeftChild():
            self.parent.leftChild = child
        else:
            self.parent.rightChild = child
        if child:
            child.parent = self.parent


class BinarySearchTree:
    """A key -> payload map stored as an AVL-balanced binary search tree.

    Supports ``tree[key] = value``, ``tree[key]``, ``key in tree``,
    ``del tree[key]``, ``len(tree)`` and in-order iteration over the keys.
    """

    def __init__(self):
        self.root = None
        self.size = 0

    def length(self):
        """Return the number of keys stored."""
        return self.size

    def __len__(self):
        return self.size

    def __iter__(self):
        """Iterate over the keys in ascending order (empty tree yields nothing)."""
        if self.root is None:
            return iter(())
        return self.root.__iter__()

    def put(self, key, val):
        """Store ``val`` under ``key``, replacing the payload of an existing key."""
        if self.root is None:
            self.root = TreeNode(key, val)
            self.size += 1
        elif self._put(key, val, self.root):
            # Only count keys that actually created a new node.
            self.size += 1

    def _put(self, key, value, currentNode):
        """Insert below ``currentNode``; return True if a new node was created.

        An existing equal key has its payload replaced and False is returned.
        """
        while True:
            if key < currentNode.key:
                if currentNode.leftChild is None:
                    currentNode.leftChild = TreeNode(key, value, parent=currentNode)
                    self.updataBalance(currentNode.leftChild)
                    return True
                currentNode = currentNode.leftChild
            elif key == currentNode.key:
                currentNode.payload = value
                return False
            else:
                if currentNode.rightChild is None:
                    currentNode.rightChild = TreeNode(key, value, parent=currentNode)
                    self.updataBalance(currentNode.rightChild)
                    return True
                currentNode = currentNode.rightChild

    def updataBalance(self, node):
        """Propagate a height increase upward from a newly inserted ``node``.

        Each ancestor's balance factor is adjusted by one; the walk stops when
        an ancestor becomes perfectly balanced (its height did not change) or
        after a rotation restores the subtree's previous height.
        """
        while True:
            if node.balanceFactor > 1 or node.balanceFactor < -1:
                self.rebalance(node)
                return
            parent = node.parent
            if parent is None:
                return
            if parent.leftChild is node:
                parent.balanceFactor += 1
            else:
                parent.balanceFactor -= 1
            if parent.balanceFactor == 0:
                return
            node = parent

    def rebalance(self, node):
        """Restore the AVL property at ``node`` with a single or double rotation."""
        if node.balanceFactor < 0:
            # Right-heavy: rotate left, first straightening a right-left zig-zag.
            if node.rightChild.balanceFactor > 0:
                self.rotateRight(node.rightChild)
            self.rotateLeft(node)
        elif node.balanceFactor > 0:
            # Left-heavy: rotate right, first straightening a left-right zig-zag.
            if node.leftChild.balanceFactor < 0:
                self.rotateLeft(node.leftChild)
            self.rotateRight(node)

    def rotateLeft(self, rotRoot):
        """Rotate the subtree at ``rotRoot`` left; its right child becomes the root."""
        newRoot = rotRoot.rightChild
        rotRoot.rightChild = newRoot.leftChild
        if newRoot.leftChild is not None:
            newRoot.leftChild.parent = rotRoot
        newRoot.parent = rotRoot.parent
        if rotRoot.isRoot():
            self.root = newRoot
        elif rotRoot.isLeftChild():
            rotRoot.parent.leftChild = newRoot
        else:
            rotRoot.parent.rightChild = newRoot
        newRoot.leftChild = rotRoot
        rotRoot.parent = newRoot
        # Order matters: newRoot's update uses rotRoot's already-updated factor.
        rotRoot.balanceFactor = rotRoot.balanceFactor + 1 - min(newRoot.balanceFactor, 0)
        newRoot.balanceFactor = newRoot.balanceFactor + 1 + max(rotRoot.balanceFactor, 0)

    def rotateRight(self, rotRoot):
        """Rotate the subtree at ``rotRoot`` right; its left child becomes the root."""
        newRoot = rotRoot.leftChild
        rotRoot.leftChild = newRoot.rightChild
        if newRoot.rightChild is not None:
            newRoot.rightChild.parent = rotRoot
        newRoot.parent = rotRoot.parent
        if rotRoot.isRoot():
            self.root = newRoot
        elif rotRoot.isLeftChild():
            rotRoot.parent.leftChild = newRoot
        else:
            rotRoot.parent.rightChild = newRoot
        newRoot.rightChild = rotRoot
        rotRoot.parent = newRoot
        # Order matters: newRoot's update uses rotRoot's already-updated factor.
        rotRoot.balanceFactor = rotRoot.balanceFactor - 1 - max(newRoot.balanceFactor, 0)
        newRoot.balanceFactor = newRoot.balanceFactor - 1 + min(rotRoot.balanceFactor, 0)

    def __setitem__(self, key, value):
        self.put(key, value)

    def get(self, key):
        """Return the payload stored under ``key``, or ``None`` when absent."""
        node = self._get(key, self.root)
        return node.payload if node is not None else None

    def _get(self, key, currentNode):
        """Return the node holding ``key`` in the subtree at ``currentNode``, or ``None``."""
        while currentNode is not None:
            if currentNode.key == key:
                return currentNode
            if key < currentNode.key:
                currentNode = currentNode.leftChild
            else:
                currentNode = currentNode.rightChild
        return None

    def __getitem__(self, key):
        # Enables ``tree[key]``; like ``get`` this returns None for a missing key.
        return self.get(key)

    def __contains__(self, key):
        # Enables ``key in tree``.
        return self._get(key, self.root) is not None

    def delete(self, key):
        """Remove ``key`` from the tree.

        Raises:
            KeyError: if ``key`` is not present.
        """
        nodeToRemove = self._get(key, self.root)
        if nodeToRemove is None:
            raise KeyError('Error,key not in tree')
        self.remove(nodeToRemove)
        self.size = self.size - 1

    def __delitem__(self, key):
        self.delete(key)

    def remove(self, currentNode):
        """Unlink ``currentNode`` from the tree and restore the AVL balance.

        ``size`` is not changed here; ``delete`` is responsible for that.
        """
        if currentNode.hasBothChildren():
            # Interior node: take over the in-order successor's data and
            # physically remove the successor, which has no left child.
            succ = currentNode.rightChild.findMin()
            currentNode.key = succ.key
            currentNode.payload = succ.payload
            currentNode = succ

        # From here on currentNode has at most one child.
        child = currentNode.leftChild or currentNode.rightChild
        parent = currentNode.parent
        if child is not None:
            child.parent = parent
        if parent is None:
            # Removing the root: its only child (or nothing) becomes the root.
            self.root = child
            return
        removedFromLeft = parent.leftChild is currentNode
        if removedFromLeft:
            parent.leftChild = child
        else:
            parent.rightChild = child
        self._rebalanceAfterRemoval(parent, removedFromLeft)

    def _rebalanceAfterRemoval(self, node, shrankLeft):
        """Walk upward from ``node`` after one of its subtrees lost a level.

        ``shrankLeft`` tells which side of ``node`` became shorter.
        """
        while node is not None:
            node.balanceFactor += -1 if shrankLeft else 1
            if node.balanceFactor > 1 or node.balanceFactor < -1:
                self.rebalance(node)
                # The rotation made ``node`` a child of the new subtree root.
                node = node.parent
            if node.balanceFactor != 0:
                # This subtree kept its height, so no ancestor is affected.
                return
            parent = node.parent
            if parent is not None:
                shrankLeft = parent.leftChild is node
            node = parent
北京大学2020公开课 AVL-Python实现代码
最新推荐文章于 2022-04-27 15:54:54 发布