# Source: https://www.jianshu.com/p/0d7438a5acb6
# My code (我的代码):
import paddle
import numpy as np
import matplotlib.pyplot as plt
import os
# Set KMP_DUPLICATE_LIB_OK=TRUE, which (presumably) lets the Intel OpenMP
# runtime tolerate a second copy of itself being loaded into the process.
# NOTE(review): paddle, numpy and matplotlib are imported above this line;
# if the duplicate-OpenMP abort still occurs, move this assignment above
# those imports — TODO confirm whether the late assignment is sufficient.
os.environ.update(KMP_DUPLICATE_LIB_OK='TRUE')

# Report which PaddlePaddle version is in use.
print(paddle.__version__)
# --------------step1----------------
# Generate training data and labels: two 2-D Gaussian clusters, one per class.
# Each row of class1 / class2 is (feature_0, feature_1, label).
SAMPLES_PER_CLASS = 100

# Class 1 (label 1): feature_0 ~ N(6, 1), feature_1 ~ N(3, 1).
x1 = np.random.normal(6, 1, size=SAMPLES_PER_CLASS)
x2 = np.random.normal(3, 1, size=SAMPLES_PER_CLASS)
y = np.ones(SAMPLES_PER_CLASS)  # label: 1
class1 = np.array([x1, x2, y]).T

# Class 0 (label 0): the mirrored cluster, feature_0 ~ N(3, 1), feature_1 ~ N(6, 1).
# Draw fresh samples rather than reusing x1/x2 with the columns swapped.
# Reusing them made every class-0 point the exact reflection of a class-1
# point across the line feature_0 == feature_1, so the two classes were not
# independent samples.
neg_f0 = np.random.normal(3, 1, size=SAMPLES_PER_CLASS)
neg_f1 = np.random.normal(6, 1, size=SAMPLES_PER_CLASS)
y = np.zeros(SAMPLES_PER_CLASS)  # label: 0
class2 = np.array([neg_f0, neg_f1, y]).T
print(class1.shape, class2.shape)

# Visualise the dataset (uncomment to inspect the two clusters before training):
# plt.scatter(class1[:, 0], class1[:, 1], c='r')
# plt.scatter(class2[:, 0], class2[:, 1], c='b', marker='*')
# plt.show()
# --------------step2----------------
# Merge both classes into one dataset, then shuffle it so that rows with
# label 0 and label 1 are interleaved instead of grouped by class.
all_data = np.vstack([class1, class2])  # class1 rows first, then class2 rows
print(all_data.shape)
np.random.shuffle(all_data)  # shuffle the rows in place (along axis 0)