本文介绍如何使用NumPy相关的数据实现数据的3D散点图可视化。
Updated: 2022 / 03 / 06
数据导入及清洗
先导入坐标集数据,再进行针对性的数据清洗。
请点击此处来下载原数据文件(.npy)。
导入
这个坐标集数据是以[a, b, c, d]
的形式集合,其中[a, b, c]
视为坐标[x, z, y]
以及[d]
视为相应坐标的ID
。
以下方导入的示例数据中的[2, 2, 1920, 480]
为例,[2, 2, 1920]
对应坐标系中的[x, z, y]
,[480]
仅仅为[2, 2, 1920]
的ID
识别码。
DATA = np.load('Exampledata_3D_scatter.npy')
'''
[([ 2, 2, 1920, 480],) ([ 1, 3, 1923, 480],)
......
([ 4, 1, 1920, 480],) ([ 3, 0, 1922, 480],)
......
([ 3, 3, 1923, 480],)]
⬆️ .shape, (69,)
⬆️ data.dtype, [('f0', '<i8', (4,))]
⬆️ type(data), <class 'numpy.ndarray'>
⬆️ type(data[0]), <class 'numpy.void'>
'''
其中,对于
numpy.void
1’ 2的处理, 截止目前为止,有2种方法:
np.array([DATA[i][0] for i in range(DATA.shape[0])])
data = np.array([DATA[i][0] for i in range(DATA.shape[0])]) ''' [[ 2 2 1920 480] ...... [ 3 3 1923 480]] ⬆️ data.shape, (69, 4); data.dtype, [('f0', '<i8', (4,))]; type(data), <class 'numpy.ndarray'> data[-1], ([ 3, 3, 1923, 480],) ⬆️ type(data[-1]), <class 'numpy.void'>; data[-1].dtype, [('f0', '<i8', (4,))] '''
np.array(DATA.tolist()).squeeze(axis=1)
data = np.array(DATA.tolist()).squeeze(axis=1) ''' DATA.tolist() [(array([ 2, 2, 1920, 480]),), ..., (array([ 3, 3, 1923, 480]),)] ⬆️ type(DATA.tolist()), <class 'list'> np.array(DATA.tolist()) [[[ 2 2 1920 480]] [[ 1 3 1923 480]] .... [[ 3 3 1923 480]]] ⬆️ np.array(DATA.tolist()).shape, (69, 1, 4); np.array(DATA.tolist()).dtype, int64; type(np.array(DATA.tolist())), <class 'numpy.ndarray'>; np.array(DATA.tolist()).squeeze(axis=1) [[ 2 2 1920 480] [ 1 3 1923 480] ...... [ 3 3 1923 480]] ⬆️ np.array(DATA.tolist()).squeeze(axis=1).shape, (69, 4); np.array(DATA.tolist()).squeeze(axis=1).dtype, int64; type(np.array(DATA.tolist()).squeeze(axis=1)),<class 'numpy.ndarray'> '''
data = np.array(DATA.tolist()).squeeze(axis=1)
'''
[[ 2 2 1920 480]
[ 1 3 1923 480]
......
[ 3 3 1923 480]]
⬆️ np.array(DATA.tolist()).squeeze(axis=1).shape, (69, 4);
np.array(DATA.tolist()).squeeze(axis=1).dtype, int64;
type(np.array(DATA.tolist()).squeeze(axis=1)),<class 'numpy.ndarray'>
'''
清洗
获取全部坐标
以[2, 2, 1920, 480]
为例,[2, 2, 1920]
对应坐标系中的[x, z, y]
,我们只在意第0、1、2列。而此处有3列。为达到目的,可以使用np.hsplit
3和[:, 0:3]
4来进行数组的切片。
X, Y, Z = np.hsplit(data[:, 0:3], 3)
'''
[[2]
...
[3]]
⬆️ X.shape, (69, 1)
[[2]
...
[3]]
⬆️ Y.shape, (69, 1)
[[1920]
...
[1923]]
⬆️ Z.shape, (69, 1)
'''
如果试图切割数组的所有列为单独的数组,
x, y, z, m = np.hsplit(data, data.shape[1])
筛选特定坐标
以[ 7 3 1923 480]
为例,对应坐标系中的[7, 3, 1923]
,我们只在意坐标集中x
为7的坐标,因此为达到目的可以使用np.where(condition, x, y)
56来对数组进行筛选。
x7 = data[np.where(data[:, 0]==7)]
'''
[[ 7 3 1923 480]
...
[ 7 1 1921 480]]
⬆️ x7.shape, (9, 4); x7.dtype, int64; type(x7), <class 'numpy.ndarray'>
'''
x7x, x7y, x7z = np.hsplit(x7[:, 0:3], 3)
'''
[[7]
......
[7]]
⬆️ x7x.shape, (9, 1)
[[3]
......
[1]]
⬆️ x7y.shape, (9, 1)
[[1923]
......
[1921]]
⬆️ x7z.shape, (9, 1)
'''
数据3D可视化
此部分使用清洗
中给出的示例数据介绍如何将数据实现3D可视化。
引入所需的库,
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
散点图
单图
fig = plt.figure(figsize=(6, 6))
ax = plt.axes(projection='3d')
# [7, y, z, ID]
i = 7
xi = data[np.where(data[:, 0] == i)]
xix, xiy, xiz = np.hsplit(xi[:, 0:3], 3)
ax.scatter3D(xix, xiy, xiz, s=40, c=xiz, marker='o', alpha=0.8, edgecolor='white')
# ticks
tickls, tickc = 5, 'black'
xmin, xmax = data[:, 0].min(), data[:, 0].max()
ymin, ymax = data[:, 1].min(), data[:, 1].max()
zmin, zmax = data[:, 2].min(), data[:, 2].max()
ax.set_xlim(xmin=xmin-1, xmax=xmax+1)
ax.set_xticklabels(range(xmin-1, xmax+1, 1), color=tickc)
ax.tick_params(axis='x', labelsize=tickls)
ax.set_ylim(bottom=ymin - 1, top=ymax + 1)
ax.set_yticklabels(range(ymin - 1, ymax + 1, 1), color=tickc)
ax.tick_params(axis='y', labelsize=tickls)
ax.set_zlim(bottom=zmin - 1, top=zmax + 1)
ax.set_zticklabels(range(zmin - 1, zmax + 1, 1), color=tickc)
ax.tick_params(axis='z', labelsize=tickls)
# test
for idx in range(xi.shape[0]):
ax.text(x=xix[idx][0], y=xiy[idx][0], z=xiz[idx][0], s=xi[idx][1:3], zdir='x', fontsize=5)
# labels
labelfd = {'size': 8, 'color': 'black'}
ax.set_xlabel('X', fontdict=labelfd)
ax.set_ylabel('Y', fontdict=labelfd)
ax.set_zlabel('Z', fontdict=labelfd)
# title
ax.set_title(f"x={i}", loc='left', fontsize=8)
plt.suptitle("Single(x=7)", x=0.5, y=0.92, fontsize=16, color='red')
plt.savefig("NumPy_Ex1_3Dscatter_Single(x=7).png", dpi=300)
效果图如下所示:
多子图
fig = plt.figure(figsize=(12, 12))
for i in range(8):
ax = fig.add_subplot(2, 4, i+1, projection='3d') # create subplot
x7 = data[np.where(data[:, 0] == 7)]
x7x, x7y, x7z = np.hsplit(x7[:, 0:3], 3) # [0~7, y, z, ID]
ax.scatter3D(x7x, x7y, x7z, s=40, c=x7z, marker='o')
# ticks
tickls, tickc = 5, 'black'
xmin, xmax = data[:, 0].min(), data[:, 0].max()
ymin, ymax = data[:, 1].min(), data[:, 1].max()
zmin, zmax = data[:, 2].min(), data[:, 2].max()
ax.set_xlim(xmin=xmin - 1, xmax=xmax + 1)
ax.set_xticklabels(range(xmin - 1, xmax + 1, 1), color=tickc)
ax.tick_params(axis='x', labelsize=tickls)
ax.set_ylim(bottom=ymin - 1, top=ymax + 1)
ax.set_yticklabels(range(ymin - 1, ymax + 1, 1), color=tickc)
ax.tick_params(axis='y', labelsize=tickls)
ax.set_zlim(bottom=zmin - 1, top=zmax + 1)
ax.set_zticklabels(range(zmin - 1, zmax + 1, 1), color=tickc)
ax.tick_params(axis='z', labelsize=tickls)
# annotation
for idx in range(x7.shape[0]):
ax.text(x=x7x[idx][0], y=x7y[idx][0], z=x7z[idx][0], s=x7[idx][1:3], zdir='x', fontsize=3)
# labels
labelfd = {'size': 8, 'color': 'black'}
ax.set_xlabel('X', fontdict=labelfd)
ax.set_ylabel('Y', fontdict=labelfd)
ax.set_zlabel('Z', fontdict=labelfd)
# title
ax.set_title(f"x={7}", loc='right', fontsize=8)
plt.suptitle("x=7", x=0.5, y=0.88, fontsize=16, color='red')
plt.savefig("8*Subplots(x=7).png", dpi=300)
效果图如下所示:
参考链接
%数据处理
% 画图
% 注释
%坐标轴
% 图例
%标题
% 保存