class BinarySerachTree:
    """Unbalanced binary search tree that maps keys to payloads.

    Supports ``tree[key] = val``, ``tree[key]`` (``None`` for a missing key),
    ``del tree[key]`` (``KeyError`` for a missing key), ``key in tree``,
    ``len(tree)`` and in-order iteration over the keys.
    """

    def __init__(self):
        self.root = None  # TreeNode at the top of the tree, or None when empty
        self.size = 0  # number of distinct keys stored

    def length(self):
        """Return the number of keys stored in the tree."""
        return self.size

    def put(self, key, val):
        """Store *val* under *key*, replacing the payload if *key* is already present."""
        if self.root is None:
            self.root = TreeNode(key, val)
            self.size += 1
            return
        # Only a key that was not already present grows the tree; re-putting an
        # existing key merely replaces its payload and must not change the size.
        isNewKey = self._get(key, self.root) is None
        self._put(key, val, self.root)
        if isNewKey:
            self.size += 1

    def _put(self, key, val, currentNode):
        """Recursively insert *key*/*val* into the subtree rooted at *currentNode*."""
        if key < currentNode.key:
            if currentNode.leftChild is not None:
                self._put(key, val, currentNode.leftChild)
            else:
                currentNode.leftChild = TreeNode(key, val, parent=currentNode)
        elif key > currentNode.key:
            if currentNode.rightChild is not None:
                self._put(key, val, currentNode.rightChild)
            else:
                currentNode.rightChild = TreeNode(key, val, parent=currentNode)
        else:
            # Equal key: overwrite the data in place, keeping the children.
            currentNode.replaceNodeData(key, val, currentNode.leftChild, currentNode.rightChild)

    def get(self, key):
        """Return the payload stored under *key*, or None if the key is absent."""
        if self.root is None:
            return None
        node = self._get(key, self.root)
        return node.payload if node is not None else None

    def _get(self, key, currentNode):
        """Return the node holding *key* in the subtree at *currentNode*, or None."""
        node = currentNode
        while node is not None:
            if key < node.key:
                node = node.leftChild
            elif key > node.key:
                node = node.rightChild
            else:
                return node
        return None

    def gethead(self, currentNode):
        """Return the height (counted in edges) of the subtree rooted at *currentNode*.

        A single node has height 0; an empty subtree (None) has height -1.
        """
        if currentNode is None:
            return -1
        return 1 + max(self.gethead(currentNode.leftChild), self.gethead(currentNode.rightChild))

    def __getitem__(self, key):
        return self.get(key)

    def __setitem__(self, k, v):
        self.put(k, v)

    def __len__(self):  # lets len(tree) report the number of keys
        return self.size

    def __iter__(self):  # in-order iteration over the keys
        if self.root is None:
            return iter(())
        return iter(self.root)

    def __contains__(self, key):
        # Look for the node itself rather than its payload, so keys whose
        # payload is falsy (0, '', None, ...) are still reported as present.
        return self.root is not None and self._get(key, self.root) is not None

    def delete(self, key):
        """Remove *key* and its payload from the tree.

        Raises:
            KeyError: if *key* is not in the tree.
        """
        node = self._get(key, self.root) if self.root is not None else None
        if node is None:
            raise KeyError('Error, key not in tree')
        if self.size == 1:
            self.root = None
        else:
            self.remove(node)
        self.size -= 1

    def __delitem__(self, key):
        self.delete(key)

    def remove(self, currentNode):
        """Unlink *currentNode* from the tree, keeping BST ordering intact.

        A node with two children takes over its in-order successor's key and
        payload, and the successor (which never has a left child) is unlinked
        instead.
        """
        left, right = currentNode.leftChild, currentNode.rightChild
        if left is not None and right is not None:
            succ = right
            while succ.leftChild is not None:
                succ = succ.leftChild
            self._replaceInParent(succ, succ.rightChild)
            currentNode.key = succ.key
            currentNode.payload = succ.payload
        else:
            # Zero or one child: that child (possibly None) takes the node's place.
            self._replaceInParent(currentNode, left if left is not None else right)

    def _replaceInParent(self, node, newChild):
        """Hang *newChild* (may be None) where *node* currently hangs, updating the root if needed."""
        parent = node.parent
        if parent is None:
            self.root = newChild
        elif parent.leftChild is node:
            parent.leftChild = newChild
        else:
            parent.rightChild = newChild
        if newChild is not None:
            newChild.parent = parent
class AVLtree(BinarySerachTree):
    """Height-balanced (AVL) binary search tree.

    Each node stores ``balanceFactor = height(left) - height(right)``. After an
    insertion or deletion the factors along the path to the root are updated,
    and any node whose factor reaches +2 or -2 is repaired with a single or
    double rotation, so every factor stays in {-1, 0, 1}.
    """

    def _put(self, key, val, currentNode):
        """Insert below *currentNode*, then retrace balance factors from the new leaf."""
        if key < currentNode.key:
            if currentNode.leftChild is not None:
                self._put(key, val, currentNode.leftChild)
            else:
                currentNode.leftChild = TreeNode(key, val, parent=currentNode)
                self.updateBalance(currentNode.leftChild)
        elif key > currentNode.key:
            if currentNode.rightChild is not None:
                self._put(key, val, currentNode.rightChild)
            else:
                currentNode.rightChild = TreeNode(key, val, parent=currentNode)
                self.updateBalance(currentNode.rightChild)
        else:
            currentNode.replaceNodeData(key, val, currentNode.leftChild, currentNode.rightChild)

    def updateBalance(self, node):
        """Retrace after an insertion: the subtree rooted at *node* grew by one level."""
        if node.balanceFactor > 1 or node.balanceFactor < -1:
            self.rebalance(node)
            return
        parent = node.parent
        if parent is not None:
            if parent.leftChild is node:
                parent.balanceFactor += 1
            elif parent.rightChild is node:
                parent.balanceFactor -= 1
            # A parent that became perfectly balanced did not change height,
            # so nothing above it needs updating.
            if parent.balanceFactor != 0:
                self.updateBalance(parent)

    def rotateLeft(self, rotRoot):
        """Rotate left around *rotRoot*; its right child becomes the subtree root."""
        newRoot = rotRoot.rightChild
        rotRoot.rightChild = newRoot.leftChild
        if newRoot.leftChild is not None:
            newRoot.leftChild.parent = rotRoot
        self._replaceInParent(rotRoot, newRoot)
        newRoot.leftChild = rotRoot
        rotRoot.parent = newRoot
        # Balance-factor update derived from the subtree heights before/after.
        rotRoot.balanceFactor = rotRoot.balanceFactor + 1 - min(newRoot.balanceFactor, 0)
        newRoot.balanceFactor = newRoot.balanceFactor + 1 + max(rotRoot.balanceFactor, 0)

    def rotateRight(self, rotRoot):
        """Rotate right around *rotRoot*; its left child becomes the subtree root."""
        newRoot = rotRoot.leftChild
        rotRoot.leftChild = newRoot.rightChild
        if newRoot.rightChild is not None:
            newRoot.rightChild.parent = rotRoot
        self._replaceInParent(rotRoot, newRoot)
        newRoot.rightChild = rotRoot
        rotRoot.parent = newRoot
        # Mirror image of rotateLeft: every sign flips and min/max swap.
        rotRoot.balanceFactor = rotRoot.balanceFactor - 1 - max(newRoot.balanceFactor, 0)
        newRoot.balanceFactor = newRoot.balanceFactor - 1 + min(rotRoot.balanceFactor, 0)

    def rebalance(self, node):
        """Repair a node whose balance factor is +2 or -2 with one or two rotations."""
        if node.balanceFactor < 0:
            # Right-heavy; a left-leaning right child needs a right-left double rotation.
            if node.rightChild.balanceFactor > 0:
                self.rotateRight(node.rightChild)
            self.rotateLeft(node)
        elif node.balanceFactor > 0:
            # Left-heavy; a right-leaning left child needs a left-right double rotation.
            if node.leftChild.balanceFactor < 0:
                self.rotateLeft(node.leftChild)
            self.rotateRight(node)

    def remove(self, currentNode):
        """Unlink *currentNode* from the tree and restore the AVL balance.

        A node with two children takes over its in-order successor's key and
        payload, and the successor (which never has a left child) is unlinked
        instead.
        """
        if currentNode.leftChild is not None and currentNode.rightChild is not None:
            succ = currentNode.rightChild
            while succ.leftChild is not None:
                succ = succ.leftChild
            currentNode.key = succ.key
            currentNode.payload = succ.payload
            currentNode = succ
        # currentNode now has at most one child, which takes its place.
        if currentNode.leftChild is not None:
            child = currentNode.leftChild
        else:
            child = currentNode.rightChild
        parent = currentNode.parent
        # Record which side of the parent loses height before relinking.
        leftShrank = parent is not None and parent.leftChild is currentNode
        self._replaceInParent(currentNode, child)
        if parent is not None:
            self._retraceDelete(parent, leftShrank)

    def _retraceDelete(self, node, leftShrank):
        """Retrace after a deletion: *node*'s left (or right) subtree lost one level.

        Walks toward the root while subtree heights keep shrinking, rotating any
        node whose balance factor reaches +2 or -2.
        """
        while node is not None:
            node.balanceFactor += -1 if leftShrank else 1
            if node.balanceFactor in (1, -1):
                return  # node's own height is unchanged; ancestors are unaffected
            if node.balanceFactor in (2, -2):
                heavy = node.leftChild if node.balanceFactor > 0 else node.rightChild
                heavyBalance = heavy.balanceFactor
                self.rebalance(node)
                if heavyBalance == 0:
                    return  # a single rotation over a balanced child keeps the height
                node = node.parent  # root of the rotated (now shorter) subtree
            # Subtree rooted at node got shorter: report it to the parent.
            parent = node.parent
            if parent is not None:
                leftShrank = parent.leftChild is node
            node = parent

    def _replaceInParent(self, node, newChild):
        """Hang *newChild* (may be None) where *node* currently hangs, updating the root if needed."""
        parent = node.parent
        if parent is None:
            self.root = newChild
        elif parent.leftChild is node:
            parent.leftChild = newChild
        else:
            parent.rightChild = newChild
        if newChild is not None:
            newChild.parent = parent
class TreeNode:
    """A node of a binary search tree (also used by the AVL tree)."""

    def __init__(self, key, val, left=None, right=None, parent=None):
        self.key = key
        self.payload = val
        self.leftChild = left
        self.rightChild = right
        self.parent = parent
        self.balanceFactor = 0  # height(left) - height(right); maintained by AVLtree

    def getbanlaceFactor(self):
        """Return the stored balance factor."""
        return self.balanceFactor

    def hasLeftChild(self):
        """Return the left child (truthy) or None."""
        return self.leftChild

    def hasRightChild(self):
        """Return the right child (truthy) or None."""
        return self.rightChild

    def isLeftChild(self):
        """True if this node is its parent's left child."""
        return self.parent is not None and self.parent.leftChild is self

    def isRightChild(self):
        """True if this node is its parent's right child."""
        return self.parent is not None and self.parent.rightChild is self

    def isRoot(self):
        """True if this node has no parent."""
        return self.parent is None

    def isLeaf(self):
        """True if this node has no children at all."""
        return self.leftChild is None and self.rightChild is None

    def hasAnyChildern(self):
        """Return a child (truthy) if the node has at least one, else None."""
        return self.leftChild or self.rightChild

    def hasBothChildern(self):
        """Return a truthy value only if the node has both a left and a right child."""
        return self.leftChild and self.rightChild

    # Correctly spelled aliases; the original names are kept for existing callers.
    getBalanceFactor = getbanlaceFactor
    hasAnyChildren = hasAnyChildern
    hasBothChildren = hasBothChildern

    def replaceNodeData(self, key, value, lc, rc):
        """Overwrite key, payload and children, re-pointing the children's parent links here."""
        self.key = key
        self.payload = value
        self.leftChild = lc
        self.rightChild = rc
        if self.leftChild is not None:
            self.leftChild.parent = self
        if self.rightChild is not None:
            self.rightChild.parent = self

    def findSuccessor(self):
        """Return the node with the next larger key (in-order successor), or None."""
        if self.rightChild is not None:
            return self.rightChild.findMin()
        # No right subtree: climb while we are a right child; the first ancestor
        # reached from its left side is the successor.
        node = self
        while node.parent is not None and node.parent.rightChild is node:
            node = node.parent
        return node.parent

    def findMin(self):
        """Return the node with the smallest key in this subtree."""
        node = self
        while node.leftChild is not None:
            node = node.leftChild
        return node

    def spliceOut(self):
        """Detach this node, linking its only child (if any) directly to its parent.

        Intended for nodes with at most one child; if both exist, the left child
        is kept. A node without a parent is left untouched (the caller must
        update the tree's root itself).
        """
        parent = self.parent
        if parent is None:
            return
        child = self.leftChild if self.leftChild is not None else self.rightChild
        if parent.leftChild is self:
            parent.leftChild = child
        elif parent.rightChild is self:
            parent.rightChild = child
        if child is not None:
            child.parent = parent

    def __iter__(self):
        """Yield the keys of this subtree in ascending (in-order) order."""
        if self.leftChild is not None:
            yield from self.leftChild
        yield self.key
        if self.rightChild is not None:
            yield from self.rightChild
if __name__ == "__main__":
    # Small demo: build an AVL tree, delete one key, and show the root key,
    # the number of keys, the tree height and the root's balance factor.
    a = AVLtree()
    a.put(12, 'a')
    a.put(7, 'b')
    a.put(10, 'd')
    a.put(8, 'c')
    del a[7]
    print(a.root.key, a.length(), a.gethead(a.root), a.root.getbanlaceFactor())
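# Output: 10 3 1 0
# 代码比较复杂，整体是按照《Python 数据结构与算法》（第二版）一书的内容编写的；书中缺少一些功能，这里按照自己的理解补充实现，可能仍有 bug。
# (English: the code is fairly involved. It mostly follows the book named above (2nd edition); features missing from the book were added based on my own understanding, so there may still be bugs.)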



