文章目录
本文给出一个自行生成数据集的代码,在使用时,给出类别数C,特征数D,每个类别的样本数量N,各组特征均值,程序即可自动生成数据集。另外针对类别数为2和3的情况,可以绘制可视化坐标图。
'''
用于生成数据集,用户给定类别数C,特征数D,每个类别的样本数量N,各组特征均值
生成特征的方差为1
'''
import numpy as np
from matplotlib import pyplot as plt
class MakeData:
def __init__(self, num_classes, num_features, num_samples, Mean):
self.C = num_classes#类别数
self.D = num_features#特征数
self.N = num_samples#样本数
self.Mean = Mean#均值集合
def produce_data(self):
cov = np.eye(self.D)#方差矩阵
'''
生成第一个类别的样本
'''
X1 = np.random.multivariate_normal(self.Mean[0], cov, self.N)
y1 = np.zeros(self.N)
X = X1
y = y1
for i in range(1,self.C):
X_ = np.random.multivariate_normal(self.Mean[i], cov, self.N)
y_ = np.ones(self.N)*i
X = np.concatenate((X, X_), axis=0)
y = np.concatenate((y, y_))
return X,y
def show_scatter(self):
X,y = self.produce_data()
if self.D == 2:
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Set1, edgecolor='k')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('Dataset')
plt.show()
elif self.D == 3:
ax = plt.subplot(projection='3d')
ax.scatter(X[:,0], X[:,1], X[:,2], c=y, cmap=plt.cm.Set1)
plt.show()
使用
if __name__ == '__main__':
M = [[3,3],[-3,-3]]
data = MakeData(2, 2, 50, M)
data.show_scatter()
应生成2类有两个特征的若干样本
if __name__ == '__main__':
M = [[3,3,3],[-3,-3,-3]]
data = MakeData(2, 3, 50, M)
data.show_scatter()
应生成2类有三个特征的若干样本