前言
我在另一篇博文机器学习(3) K近邻算法(KNN)介绍及C++实现中介绍了K近邻算法及KD树的实现方法,博文编写过程中需要显式绘制二叉树将其表示出来。初始方法是使用C++生成KD树,并根据graphviz的dot语言逐行编写KD树。通过查阅相关文献,发现使用Python绘制KD树的过程并不繁琐,于是本文介绍使用Python绘制graphviz二叉查找树的图形。另外,介绍博文中是绘制二维平面上KD树、收敛过程如何表示的图形化绘制方法
graphviz简介
graphviz是一种便于绘制流程图、树形结构等的图形可视化软件。掌握基础的脚本语言就可以轻松绘制属于自己的流程图、二叉树图等内容。
graphviz安装
安装graphviz流程:打开graphviz下载链接,依据网页提示选择属于自己的平台安装包。我安装的是windows10下的stable_windows_10_msbuild_Release_Win32_graphviz-2.46.0-win32.zip。下载完成后解压到C:/Software/graphviz等自己习惯的路径下,将C:/software/graphviz/bin加入到系统环境变量中,重启电脑以配置graphviz环境变量。使用Python绘制graphviz流程图,需要在安装python3环境后,在命令行pip install graphviz即可。
安装python3环境流程:以windows10为例,打开Python3-Windows-下载地址,下载最新版本安装包或适合自己的版本安装包,安装到本地过程中记得配置环境变量Path,不再赘述。
graphviz语法
我将着重讲述使用graphviz绘制二叉树涉及到的语法知识,更详细的语法知识参见graphviz官网说明文档。我将讲解两方面知识,第一是使用命令行编译dot文件,第二是使用python直接生成.gv文件。
首先介绍命令行下如何使用graphviz语法编写一个二叉树。在自己的路径下生成一个demo.dot文档,文档中内容如下:
// demo.dot
digraph {
node [shape=circle]
1 [label="(7,2)"]
2 [label="(5,4)"]
3 [label="(2,3)"]
4 [label="(6,6)", style="invis"]
1 -> 2
1 -> 4
1 -> 3 [style="invis"]
}
我将文档保存在了C:\File\demo.dot路径下。在命令行中执行:
cd C:\File
dot -Tpng demo.png -o demo.dot
就会在路径C:\File下生成demo.png,图像如图所示。
digraph G{
...
}
表示这是一个有向图,图中的边都带箭头。
...
node [shape='circle']
...
表示图中的节点都是圆形。
1 [label='(7,2)']
声明一个节点,节点记为1,其内容为字符串"(7,2)"
4 [label="(6,6)", style="invis"]
声明一个节点,节点记为4,其内容为字符串"(6,6)",并且这个节点在图中不显示。
1 -> 2
声明一条从1指向2的边。
1 -> 3 [style="invis"]
声明一条从1指向3的边,并且这条边在图中不显示。
据此,为了保证二叉树有序、对齐显示,我们在绘制二叉树的过程中,左右子树中间添加一个不可见的边和不可见节点,实现图形的对齐效果。如果使用C++强行绘制graphviz,就根据.dot文件的语法格式,向文件流中采用先根遍历的方法书写dot文本,使用文件流记得#include <fstream>
。C++实现方法如下:
#include <fstream>
void drawKDTree(node* root, string path) {
// 等价于先根序列。
// path = "tree.dot"
ofstream fout(path);
string tab = " ";
fout << "digraph G{" << endl;
fout << tab << "node[shape=circle]" << endl;
int N = data.size()+1;
preOrderDraw(root, fout, N);
fout << "}" << endl;
fout.close();
}
void preOrderDraw(node* root, ofstream& fout, int& nullIndex) {
string tab = " ";
// 先根序列,绘制当前节点的内容。
fout << tab << root->index << "[group=" << root->index << ", label=\"(" << data[root->index][0];
for (int i = 1; i < n; i++) {
fout << "," << data[root->index][i];
}
fout << ")\"]" << endl;
// 绘制左节点的内容
if (root->left) {
// 当左节点非空的时候,需要绘制一条伸向左节点的有向边。
fout << tab << root->index << " -> " << root->left->index << endl;
// 递归遍历左子树。
preOrderDraw(root->left, fout, nullIndex);
}
else {
// 左节点为空的时候,为了保证图形的整洁有序,绘制左侧空节点占位。边与节点都为不可见[style=invis]。
fout << tab << root->index << " -> _" << nullIndex << "[style=invis]" << endl;
fout << tab << "_" << nullIndex++ << " [style=invis]" << endl;
}
// 为了二叉树的图形可以相当漂亮美观且对齐,设置一个中间空节点保证左右两侧对齐。
fout << tab << root->index<<" -> "<< "_" << root->index << "[weight=10, group=" << root->index << ", style=invis]" << endl;
fout << tab << "_" << root->index << "[style=invis]" << endl;
// 同上,绘制右节点的内容。
if (root->right) {
// 当右节点非空的时候,需要绘制一条伸向右节点的有向边。
fout << tab << root->index << " -> " << root->right->index << endl;
// 递归遍历右子树。
preOrderDraw(root->right, fout, nullIndex);
}
else {
// 右节点为空的时候,为了保证图形的整洁有序,绘制右侧空节点占位。边与节点都为不可见[style=invis]。
fout << tab << root->index << " -> _" << nullIndex << "[style=invis]" << endl;
fout << tab << "_" << nullIndex++ << " [style=invis]" << endl;
}
}
下面介绍Python的graphviz语法。为了绘制同样一棵上面的树,我们只需要做这几行代码,即可生成一棵二叉树并展示出来。
// demo.py
from graphviz import Digraph
dot = Digraph(node_attr={'shape': 'circle'})
dot.node(1,"(7,2)")
dot.node(2,"(5,4)")
dot.node(3,"(2,3)")
dot.node(4,"(6,6)",style="invis")
dot.edge(1,2)
dot.edge(1,4)
dot.edge(1,3,style="invis")
dot.view()
据此,使用先根遍历的方式,同样根据二叉树的节点,绘制边和点即可。
import numpy as np
from graphviz import Digraph
from matplotlib import pyplot as plt
from matplotlib.pyplot import MultipleLocator
#data = [[2,3],[6, 4],[9, 6],[4, 7],[8, 1],[7, 2], [8,2], [10,4], [6,6]]
data = [[7,2], [5,4], [9,6], [2,3], [4,7], [8,1]]
data = np.array(data)
# 节点
class node:
def __init__(self, _data=None, _left=None, _right=None, _father=None, _dim=None, _index=None, _visiable=True):
self.data = _data
self.left = _left
self.right = _right
self.father = _father
self.dim = _dim
self.index = _index
self.visiable = _visiable
def getData(self):
s = "("
for i in range(self.data.size):
if i!=0:
s += ','
s+=str(self.data[i])
s += ")"
return s
def __str__(self):
if(self.visiable):
return str(self.index)
else:
return "_invis"+str(self.index)
dataIndex = 1
def drawKDTree(data, depth, k, dot):
# 根据数据生成KD树
dim = depth % k
length = data.shape[0]
if(length==0):
return None, dot
index = []
for i in range(length):
index.append([data[i][dim], i])
index.sort()
root = data[index[length//2][1]]
left = [data[index[i][1]] for i in range(length//2)]
left = np.array(left)
right = [data[index[i][1]] for i in range(length//2+1, length)]
right = np.array(right)
global dataIndex
root_node = node(_data=root, _dim=dim, _index=dataIndex)
dataIndex+=1
dot.node(str(root_node.index), root_node.getData())
root_node.left, dot=drawKDTree(left, depth+1, k, dot)
if(root_node.left==None):
pass
dot.node("_left"+str(root_node.index), root_node.getData(), style="invis")
dot.edge(str(root_node.index), "_left"+str(root_node.index), style="invis")
else:
dot.edge(str(root_node.index), str(root_node.left.index))
dot.node("_middle"+str(root_node.index), root_node.getData(), style="invis")
dot.edge(str(root_node.index), "_middle"+str(root_node.index), style="invis", weight="10")
root_node.right, dot=drawKDTree(right, depth+1, k, dot)
if(root_node.right==None):
pass
dot.node("_right"+str(root_node.index), root_node.getData(), style="invis")
dot.edge(str(root_node.index), "_right"+str(root_node.index), style="invis")
else:
dot.edge(str(root_node.index), str(root_node.right.index))
if(root_node.left):
root_node.left.father=root_node
if(root_node.right):
root_node.right.father=root_node
return root_node, dot
dot = Digraph(node_attr={'shape': 'circle'})
_, dot = drawKDTree(data, 0, 2, dot)
dot.view()
print(dot.source)
绘制平面KNN模拟图
只需要通过pyplot在生成KD树的过程中,控制节点的维度以及左右边界,即可绘制分类的直线段;通过绘制scatter散点图,将点标记在图中;通过计算半径,绘制以待查询节点为圆心的圆形。特别注意的是,由于pyplot不支持深拷贝、也无法撤销某一步操作,因此想要在同一个背景下绘制不同的图形,只有自己设置一个函数以保证每次都可以同样调用生成同一块背景,并在该背景上绘制新的图形。这里的Python函数不包括数据的预处理、标签、投票等内容,仅仅是用于绘制图形而用的脚本内容。
import numpy as np
from graphviz import Digraph
from matplotlib import pyplot as plt
from matplotlib.pyplot import MultipleLocator
#data = [[2,3],[6, 4],[9, 6],[4, 7],[8, 1],[7, 2], [8,2], [10,4], [6,6]]
data = [[7,2], [5,4], [9,6], [2,3], [4,7], [8,1]]
data = np.array(data)
# 节点
class node:
def __init__(self, _data=None, _left=None, _right=None, _father=None, _dim=None, _index=None, _visiable=True):
self.data = _data
self.left = _left
self.right = _right
self.father = _father
self.dim = _dim
self.index = _index
self.visiable = _visiable
def getData(self):
s = "("
for i in range(self.data.size):
if i!=0:
s += ','
s+=str(self.data[i])
s += ")"
return s
def __str__(self):
if(self.visiable):
return str(self.index)
else:
return "_invis"+str(self.index)
# 生成KD树,并绘制一个完整的平面图形。
def createTree(data, depth, k, l, r, d, u):
dim = depth % k
length = data.shape[0]
if(length==0):
return None
index = []
for i in range(length):
index.append([data[i][dim], i])
index.sort()
root = data[index[length//2][1]]
left = [data[index[i][1]] for i in range(length//2)]
left = np.array(left)
right = [data[index[i][1]] for i in range(length//2+1, length)]
right = np.array(right)
root_node = node(_data=root, _dim=dim)
if(dim == 0):
plt.plot([root[0]]*(u-d+1), range(d, u+1))
root_node.left=createTree(left, depth+1, k, l, root[0], d, u)
root_node.right=createTree(right, depth+1, k, root[0], r, d, u)
if(root_node.left):
root_node.left.father=root_node
if(root_node.right):
root_node.right.father=root_node
else:
plt.plot(range(l, r+1), [root[1]]*(r-l+1))
root_node.left=createTree(left, depth+1, k, l, r, d, root[1])
root_node.right=createTree(right, depth+1, k, l, r, root[1], u)
if(root_node.left):
root_node.left.father=root_node
if(root_node.right):
root_node.right.father=root_node
return root_node
# 绘制分类超平面
def drawOri(data):
fig, ax = plt.subplots()
fig.set_size_inches(5, 5)
data = np.array(data)
mmax = np.max(data)+1
mmin = np.min(data)-1
major_locator=MultipleLocator(1)
plt.scatter(data[:,0], data[:,1])
plt.xlim(mmin, mmax)
plt.ylim(mmin, mmax)
ax = plt.gca()
ax.xaxis.set_major_locator(major_locator)
ax.yaxis.set_major_locator(major_locator)
return createTree(data, 0, 2, mmin, mmax, mmin, mmax)
# 绘制标记点及分类超平面
def drawPic(x, data):
fig, ax = plt.subplots()
fig.set_size_inches(5, 5)
data = np.array(data)
mmax = np.max(data)+1
mmin = np.min(data)-1
major_locator=MultipleLocator(1)
plt.scatter(data[:,0], data[:,1])
plt.scatter([x[0]], [x[1]], marker='x')
plt.xlim(mmin, mmax)
plt.ylim(mmin, mmax)
ax = plt.gca()
ax.xaxis.set_major_locator(major_locator)
ax.yaxis.set_major_locator(major_locator)
return createTree(data, 0, 2, mmin, mmax, mmin, mmax)
# 计算两点间的欧式距离
def distance(a, b):
return ((a[0]-b[0])**2+(a[1]-b[1])**2)**0.5
# 寻找叶节点
def findLeaf(root, x, stack):
if(root==None):
return stack
stack.append(root)
if(x[root.dim]<=root.data[root.dim]):
return findLeaf(root.left, x, stack)
else:
return findLeaf(root.right, x, stack)
# 寻找最近邻节点,并绘制图形
def searchNearest(root, x, differt_pic=True, show=False):
plt.scatter([x[0]], [x[1]], marker='x')
stack = []
stack = findLeaf(root, x, stack)
nearN = stack[-1]
minD = distance(stack[-1].data, x)
visted = set()
path = 1
while(stack):
top = stack[-1]
visted.add(top)
stack.pop()
dis = distance(top.data, x)
if(dis < minD):
minD = dis
nearN = top
if show:
plt.show()
if differt_pic:
# 重新绘制一张底图。
drawPic(x, data)
ax = plt.gca()
ax.scatter(top.data[0], top.data[1], marker='x', s=200)
ax.plot([x[0], nearN.data[0]], [x[1], nearN.data[1]])
theta = np.arange(0, 2*np.pi, 0.01)
xx = x[0] + minD * np.cos(theta)
yy = x[1] + minD * np.sin(theta)
plt.plot(xx, yy)
plt.savefig("{0}.png".format(path))
path += 1
left = x[top.dim] - minD
right = x[top.dim] + minD
if(left <= top.data[top.dim] and top.left != None and top.left not in visted):
stack.append(top.left)
if(right >= top.data[top.dim] and top.right != None and top.right not in visted):
stack.append(top.right)
return nearN
x = [4, 3]
drawOri(data)
plt.savefig("0.png")
root = drawPic(x, data, differt_pic=True, show=False)
searchNearest(root, x)
绘制图形展示如下:
至此,《统计学习方法》第三章的全部内容都更新完毕,在我的Gtihub中有详细代码,欢迎查阅。