使用 t-SNE 观察训练迭代过程中神经网络所提取特征的数据分布：十个类别、二维显示、添加图例、去除上边框和右边框，并保存图片
导入所需的包:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('TkAgg')
from sklearn.manifold import TSNE
画图函数（按类别绘制二维散点图）:
def plot_embedding(data, label, title):
    """Scatter-plot 2-D points coloured by class label, with a legend.

    Args:
        data: array-like of shape (n, 2); column 0 is plotted as x and
            column 1 as y.
        label: n class labels, shape (n,) or (n, 1), one per row of ``data``
            (must be convertible with ``np.asarray``). Only labels 0-9 are
            drawn; any other value (including 10) is skipped.
        title: currently unused (the title call is intentionally disabled).

    Returns:
        The matplotlib Figure that was drawn on.
    """
    data = np.asarray(data)
    # Accept both (n,) and (n, 1) label vectors.
    label = np.asarray(label).reshape(-1)

    # One colour per class 0-9. The Set3 colours are reshaped into a (1, 4)
    # RGBA row so scatter does not mistake a 4-tuple for per-point values.
    colors = ['r', 'g', 'b', 'k', 'c', 'm', 'y'] + [
        np.array(plt.cm.Set3(i)).reshape(1, 4) for i in range(3)
    ]
    # NOTE(review): 'O3' breaks the N / B0-B2 / I0-I2 / O0-O? pattern;
    # 'O2' may have been intended -- confirm before changing.
    names = ('N', 'B0', 'B1', 'B2', 'I0', 'I1', 'I2', 'O0', 'O1', 'O3')

    fig = plt.figure()
    ax = plt.subplot(111)
    handles = []
    for cls, color in enumerate(colors):
        mask = label == cls
        handles.append(ax.scatter(data[mask, 0], data[mask, 1], s=10, c=color))

    # Legend anchored just past the right edge of the axes.
    ax.legend(handles, names, loc=(0.97, 0.5))
    # Hide the right and top border lines.
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    return fig
使用 t-SNE 对提取的特征进行降维，然后画图:
def plot_2D(data, label, epoch):
    """Reduce features to 2-D with t-SNE and scatter-plot them by class.

    Args:
        data: feature matrix, one row per sample; passed directly to
            ``TSNE.fit_transform``.
        label: class label for each row of ``data``.
        epoch: not used inside this function body; kept in the signature so
            callers can pass the iteration number (e.g. to name a saved image).

    Returns:
        The matplotlib Figure produced by ``plot_embedding``, so the caller
        can save and/or close it.
    """
    print('Computing t-SNE embedding')
    # Project the features to two dimensions; a fixed random_state keeps
    # the embedding reproducible between runs.
    tsne = TSNE(n_components=2, init='pca', random_state=0)
    result = tsne.fit_transform(data)
    fig = plot_embedding(result, label, 't-SNE embedding')
    # The legend is placed outside the axes; shrink the plotting area so the
    # legend is not clipped when the figure is saved.
    fig.subplots_adjust(right=0.8)
    return fig
将 figure 保存到指定文件夹（这段代码应放在 plot_2D 函数末尾，此时 epoch 已定义）:
# Save the current figure as D:/image_train/<epoch>.png. The original string
# concatenation embedded literal quote characters in the file name
# (e.g. "'5'.png"). Closing the figure afterwards keeps repeated per-epoch
# calls from accumulating open figures.
plt.savefig(f"D:/image_train/{epoch}.png")
plt.close()
保存的图片: