Python数据分析(02) graphviz绘制KD二叉查找树

本文介绍使用Python和graphviz绘制K近邻(KNN)算法中的二叉树及平面模拟图的方法。涵盖graphviz的基本语法、安装步骤以及Python绘制流程。适用于希望直观理解KNN算法流程的学习者。
摘要由CSDN通过智能技术生成

前言

        我在另一篇博文机器学习(3) K近邻算法(KNN)介绍及C++实现中介绍了K近邻算法及KD树的实现方法,博文编写过程中需要显式绘制二叉树将其表示出来。初始方法是使用C++生成KD树,并根据graphviz的dot语言逐行编写KD树。通过查阅相关文献,发现使用Python绘制KD树的过程并不繁琐,于是本文介绍使用Python绘制graphviz二叉查找树的图形。另外,介绍博文中是绘制二维平面上KD树、收敛过程如何表示的图形化绘制方法

graphviz简介

        graphviz是一种便于绘制流程图、树形结构等的图形可视化软件。掌握基础的脚本语言就可以轻松绘制属于自己的流程图、二叉树图等内容。

graphviz安装

        安装graphviz流程:打开graphviz下载链接,依据网页提示选择属于自己的平台安装包。我安装的是windows10下的stable_windows_10_msbuild_Release_Win32_graphviz-2.46.0-win32.zip。下载完成后解压到C:/Software/graphviz等自己习惯的路径下,将C:/software/graphviz/bin加入到系统环境变量中,重启电脑以配置graphviz环境变量。使用Python绘制graphviz流程图,需要在安装python3环境后,在命令行pip install graphviz即可。
        安装python3环境流程:以windows10为例,打开Python3-Windows-下载地址,下载最新版本安装包或适合自己的版本安装包,安装到本地过程中记得配置环境变量Path,不再赘述。

graphviz语法

        我将着重讲述使用graphviz绘制二叉树涉及到的语法知识,更详细的语法知识参见graphviz官网说明文档。我将讲解两方面知识,第一是使用命令行编译dot文件,第二是使用python直接生成.gv文件。
        首先介绍命令行下如何使用graphviz语法编写一个二叉树。在自己的路径下生成一个demo.dot文档,文档中内容如下:

// demo.dot
digraph {
	node [shape=circle]
	1 [label="(7,2)"]
	2 [label="(5,4)"]
	3 [label="(2,3)"]
	4 [label="(6,6)", style="invis"]
	1 -> 2
	1 -> 4
	1 -> 3 [style="invis"]
}

        我将文档保存在了C:\File\demo.dot路径下。在命令行中执行:

cd C:\File
dot -Tpng demo.png -o demo.dot

        就会在路径C:\File下生成demo.png,图像如图所示。
demo.png

demo.png
        下面解释dot文件中每一行的含义。
digraph G{
	...
}
表示这是一个有向图,图中的边都带箭头。
...
	node [shape='circle']
...
表示图中的节点都是圆形。
1 [label='(7,2)']
声明一个节点,节点记为1,其内容为字符串"(7,2)"
4 [label="(6,6)", style="invis"]
声明一个节点,节点记为4,其内容为字符串"(6,6)",并且这个节点在图中不显示。
1 -> 2
声明一条从1指向2的边。
1 -> 3 [style="invis"]
声明一条从1指向3的边,并且这条边在图中不显示。

        据此,为了保证二叉树有序、对齐显示,我们在绘制二叉树的过程中,左右子树中间添加一个不可见的边和不可见节点,实现图形的对齐效果。如果使用C++强行绘制graphviz,就根据.dot文件的语法格式,向文件流中采用先根遍历的方法书写dot文本,使用文件流记得#include <fstream>。C++实现方法如下:

#include <fstream>
void drawKDTree(node* root, string path) {
	// 等价于先根序列。
	// path = "tree.dot"
	ofstream fout(path);
	string tab = "    ";
	fout << "digraph G{" << endl;
	fout << tab << "node[shape=circle]" << endl;
	int N = data.size()+1;
	preOrderDraw(root, fout, N);
	fout << "}" << endl;
	fout.close();
	
}
void preOrderDraw(node* root, ofstream& fout, int& nullIndex) {
	string tab = "    ";
	// 先根序列,绘制当前节点的内容。
	fout << tab << root->index << "[group=" << root->index << ", label=\"(" << data[root->index][0];
	for (int i = 1; i < n; i++) {
		fout << "," << data[root->index][i];
	}
	fout << ")\"]" << endl;
	// 绘制左节点的内容
	if (root->left) {
		// 当左节点非空的时候,需要绘制一条伸向左节点的有向边。
		fout << tab << root->index << " -> " << root->left->index << endl;
		// 递归遍历左子树。
		preOrderDraw(root->left, fout, nullIndex);
	}
	else {
		// 左节点为空的时候,为了保证图形的整洁有序,绘制左侧空节点占位。边与节点都为不可见[style=invis]。
		fout << tab << root->index << " -> _" << nullIndex << "[style=invis]" << endl;
		fout << tab << "_" << nullIndex++ << " [style=invis]" << endl;
	}
	// 为了二叉树的图形可以相当漂亮美观且对齐,设置一个中间空节点保证左右两侧对齐。
	fout << tab << root->index<<" -> "<< "_" << root->index << "[weight=10, group=" << root->index << ", style=invis]" << endl;
	fout << tab << "_" << root->index << "[style=invis]" << endl;
	// 同上,绘制右节点的内容。
	if (root->right) {
		// 当右节点非空的时候,需要绘制一条伸向右节点的有向边。
		fout << tab << root->index << " -> " << root->right->index << endl;
		// 递归遍历右子树。
		preOrderDraw(root->right, fout, nullIndex);
	}
	else {
		// 右节点为空的时候,为了保证图形的整洁有序,绘制右侧空节点占位。边与节点都为不可见[style=invis]。
		fout << tab << root->index << " -> _" << nullIndex << "[style=invis]" << endl;
		fout << tab << "_" << nullIndex++ << " [style=invis]" << endl;
	}
}

        下面介绍Python的graphviz语法。为了绘制同样一棵上面的树,我们只需要做这几行代码,即可生成一棵二叉树并展示出来。

// demo.py
from graphviz import Digraph
dot = Digraph(node_attr={'shape': 'circle'})
dot.node(1,"(7,2)")
dot.node(2,"(5,4)")
dot.node(3,"(2,3)")
dot.node(4,"(6,6)",style="invis")
dot.edge(1,2)
dot.edge(1,4)
dot.edge(1,3,style="invis")
dot.view()

        据此,使用先根遍历的方式,同样根据二叉树的节点,绘制边和点即可。

import numpy as np
from graphviz import Digraph
from matplotlib import pyplot as plt
from matplotlib.pyplot import MultipleLocator
#data = [[2,3],[6, 4],[9, 6],[4, 7],[8, 1],[7, 2], [8,2], [10,4], [6,6]]
data = [[7,2], [5,4], [9,6], [2,3], [4,7], [8,1]]
data = np.array(data)

# 节点
class node:
	def __init__(self, _data=None, _left=None, _right=None, _father=None, _dim=None, _index=None, _visiable=True):
		self.data = _data
		self.left = _left
		self.right = _right
		self.father = _father
		self.dim = _dim
		self.index = _index
		self.visiable = _visiable
	def getData(self):
		s = "("
		for i in range(self.data.size):
			if i!=0:
				s += ','
			s+=str(self.data[i])
		s += ")"
		return s
	def __str__(self):
		if(self.visiable):
			return str(self.index)
		else:
			return "_invis"+str(self.index)

dataIndex = 1
def drawKDTree(data, depth, k, dot):
	# 根据数据生成KD树
	dim = depth % k
	length = data.shape[0]
	if(length==0):
		return None, dot
	index = []
	for i in range(length):
		index.append([data[i][dim], i])
	index.sort()
	root = data[index[length//2][1]]
	left = [data[index[i][1]] for i in range(length//2)]
	left = np.array(left)
	right = [data[index[i][1]] for i in range(length//2+1, length)]
	right = np.array(right)
	global dataIndex
	root_node = node(_data=root, _dim=dim, _index=dataIndex)
	dataIndex+=1

	dot.node(str(root_node.index), root_node.getData())

	root_node.left, dot=drawKDTree(left, depth+1, k, dot)
	if(root_node.left==None):
		pass
		dot.node("_left"+str(root_node.index), root_node.getData(), style="invis")
		dot.edge(str(root_node.index), "_left"+str(root_node.index), style="invis")
	else:
		dot.edge(str(root_node.index), str(root_node.left.index))

	dot.node("_middle"+str(root_node.index), root_node.getData(), style="invis")
	dot.edge(str(root_node.index), "_middle"+str(root_node.index), style="invis", weight="10")

	root_node.right, dot=drawKDTree(right, depth+1, k, dot)

	if(root_node.right==None):
		pass
		dot.node("_right"+str(root_node.index), root_node.getData(), style="invis")
		dot.edge(str(root_node.index), "_right"+str(root_node.index), style="invis")
	else:
		dot.edge(str(root_node.index), str(root_node.right.index))

	if(root_node.left):
		root_node.left.father=root_node
	if(root_node.right):
		root_node.right.father=root_node
	
	return root_node, dot

dot = Digraph(node_attr={'shape': 'circle'})
_, dot = drawKDTree(data, 0, 2, dot)
dot.view()
print(dot.source)

绘制平面KNN模拟图

        只需要通过pyplot在生成KD树的过程中,控制节点的维度以及左右边界,即可绘制分类的直线段;通过绘制scatter散点图,将点标记在图中;通过计算半径,绘制以待查询节点为圆心的圆形。特别注意的是,由于pyplot不支持深拷贝、也无法撤销某一步操作,因此想要在同一个背景下绘制不同的图形,只有自己设置一个函数以保证每次都可以同样调用生成同一块背景,并在该背景上绘制新的图形。这里的Python函数不包括数据的预处理、标签、投票等内容,仅仅是用于绘制图形而用的脚本内容。

import numpy as np
from graphviz import Digraph
from matplotlib import pyplot as plt
from matplotlib.pyplot import MultipleLocator
#data = [[2,3],[6, 4],[9, 6],[4, 7],[8, 1],[7, 2], [8,2], [10,4], [6,6]]
data = [[7,2], [5,4], [9,6], [2,3], [4,7], [8,1]]
data = np.array(data)

# 节点
class node:
	def __init__(self, _data=None, _left=None, _right=None, _father=None, _dim=None, _index=None, _visiable=True):
		self.data = _data
		self.left = _left
		self.right = _right
		self.father = _father
		self.dim = _dim
		self.index = _index
		self.visiable = _visiable
	def getData(self):
		s = "("
		for i in range(self.data.size):
			if i!=0:
				s += ','
			s+=str(self.data[i])
		s += ")"
		return s
	def __str__(self):
		if(self.visiable):
			return str(self.index)
		else:
			return "_invis"+str(self.index)

# 生成KD树,并绘制一个完整的平面图形。
def createTree(data, depth, k, l, r, d, u):
	dim = depth % k
	length = data.shape[0]
	if(length==0):
		return None
	index = []
	for i in range(length):
		index.append([data[i][dim], i])
	index.sort()
	root = data[index[length//2][1]]
	left = [data[index[i][1]] for i in range(length//2)]
	left = np.array(left)
	right = [data[index[i][1]] for i in range(length//2+1, length)]
	right = np.array(right)
	root_node = node(_data=root, _dim=dim)

	if(dim == 0):
		plt.plot([root[0]]*(u-d+1), range(d, u+1))
		root_node.left=createTree(left, depth+1, k, l, root[0], d, u)
		root_node.right=createTree(right, depth+1, k, root[0], r, d, u)
		if(root_node.left):
			root_node.left.father=root_node
		if(root_node.right):
			root_node.right.father=root_node
	else:
		plt.plot(range(l, r+1), [root[1]]*(r-l+1))
		root_node.left=createTree(left, depth+1, k, l, r, d, root[1])
		root_node.right=createTree(right, depth+1, k, l, r, root[1], u)
		if(root_node.left):
			root_node.left.father=root_node
		if(root_node.right):
			root_node.right.father=root_node
	return root_node

# 绘制分类超平面
def drawOri(data):
	fig, ax = plt.subplots()
	fig.set_size_inches(5, 5)
	data = np.array(data)
	mmax = np.max(data)+1
	mmin = np.min(data)-1
	major_locator=MultipleLocator(1)
	plt.scatter(data[:,0], data[:,1])
	plt.xlim(mmin, mmax)
	plt.ylim(mmin, mmax)
	ax = plt.gca()
	ax.xaxis.set_major_locator(major_locator)
	ax.yaxis.set_major_locator(major_locator)
	return createTree(data, 0, 2, mmin, mmax, mmin, mmax)

# 绘制标记点及分类超平面
def drawPic(x, data):
	fig, ax = plt.subplots()
	fig.set_size_inches(5, 5)
	data = np.array(data)
	mmax = np.max(data)+1
	mmin = np.min(data)-1
	major_locator=MultipleLocator(1)
	plt.scatter(data[:,0], data[:,1])
	plt.scatter([x[0]], [x[1]], marker='x')
	plt.xlim(mmin, mmax)
	plt.ylim(mmin, mmax)
	ax = plt.gca()
	ax.xaxis.set_major_locator(major_locator)
	ax.yaxis.set_major_locator(major_locator)
	return createTree(data, 0, 2, mmin, mmax, mmin, mmax)

# 计算两点间的欧式距离
def distance(a, b):
	return ((a[0]-b[0])**2+(a[1]-b[1])**2)**0.5

# 寻找叶节点
def findLeaf(root, x, stack):
	if(root==None):
		return stack
	stack.append(root)
	if(x[root.dim]<=root.data[root.dim]):
		return findLeaf(root.left, x, stack)
	else:
		return findLeaf(root.right, x, stack)


# 寻找最近邻节点,并绘制图形
def searchNearest(root, x, differt_pic=True, show=False):
	plt.scatter([x[0]], [x[1]], marker='x')
	stack = []
	stack = findLeaf(root, x, stack)
	nearN = stack[-1]
	minD = distance(stack[-1].data, x)
	visted = set()
	path = 1
	while(stack):
		top = stack[-1]
		visted.add(top)
		stack.pop()
		dis = distance(top.data, x)
		if(dis < minD):
			minD = dis
			nearN = top
		if show:
			plt.show()
		if differt_pic:
			# 重新绘制一张底图。
			drawPic(x, data)
		ax = plt.gca()
		ax.scatter(top.data[0], top.data[1], marker='x', s=200)
		ax.plot([x[0], nearN.data[0]], [x[1], nearN.data[1]])
		theta = np.arange(0, 2*np.pi, 0.01)
		xx = x[0] + minD * np.cos(theta)
		yy = x[1] + minD * np.sin(theta)
		plt.plot(xx, yy)
		plt.savefig("{0}.png".format(path))
		path += 1

		left = x[top.dim] - minD
		right = x[top.dim] + minD
		if(left <= top.data[top.dim] and top.left != None and top.left not in visted):
			stack.append(top.left)
		if(right >= top.data[top.dim] and top.right != None and top.right not in visted):
			stack.append(top.right)
	return nearN


x = [4, 3]
drawOri(data)
plt.savefig("0.png")
root = drawPic(x, data, differt_pic=True, show=False)
searchNearest(root, x)

        绘制图形展示如下:
1
2

        至此,《统计学习方法》第三章的全部内容都更新完毕,在我的Gtihub中有详细代码,欢迎查阅。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ProfSnail

谢谢老哥嗷

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值