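Dataset
The wine dataset contains three classes of wine and 178 samples in total, each with 13 features. The format of the sample data is shown in the figure below.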
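As a rough illustration of the record layout, the sketch below parses one comma-separated line in the wine.data format: the class label comes first, followed by the 13 feature values. The `sample_line` value and the `parse_record` helper are illustrative assumptions, not part of the original code.

```python
# Illustrative record in the wine.data layout: class label, then 13 features.
sample_line = "1,14.23,1.71,2.43,15.6,127,2.8,3.06,.28,2.29,5.64,1.04,3.92,1065"


def parse_record(line):
    """Split one comma-separated record into (label, features)."""
    fields = line.strip().split(",")
    label = int(fields[0])                      # first column: class label
    features = [float(v) for v in fields[1:]]   # remaining columns: features
    return label, features


label, features = parse_record(sample_line)
print(label, len(features))  # prints: 1 13
```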
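Data loading and preprocessing
Read the data from wine.data with loadDateSet, then reduce it to two dimensions with LL, which L2-normalizes each sample and applies linear discriminant analysis (LDA).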
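```python
import numpy as np
from sklearn import preprocessing
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis


def loadDateSet(filename):
    """Read wine.data: column 0 is the class label, the rest are features."""
    dataMat = []
    labelMat = []
    with open(filename) as fr:  # the file is closed automatically
        for line in fr:
            line = line.strip()
            if not line:  # skip blank lines
                continue
            curLine = line.split(',')
            dataMat.append(list(map(float, curLine[1:])))
            labelMat.append(int(curLine[0]))
    return np.array(dataMat), np.array(labelMat)


def LL(x, y):
    """L2-normalize each sample, then project onto 2 LDA components."""
    x_norm = preprocessing.normalize(x, norm='l2')
    lda = LinearDiscriminantAnalysis(n_components=2)
    x_new = lda.fit_transform(x_norm, y)
    return x_new
```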
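The first step in LL, `preprocessing.normalize(x, norm='l2')`, rescales each sample (row) to unit Euclidean length rather than scaling each feature column. The arithmetic can be sketched without any dependencies as follows; `l2_normalize_rows` is a hypothetical helper written for illustration, not scikit-learn's implementation.

```python
import math


def l2_normalize_rows(rows):
    """Scale each row to unit Euclidean length; all-zero rows stay unchanged."""
    normalized = []
    for row in rows:
        norm = math.sqrt(sum(v * v for v in row))
        normalized.append([v / norm for v in row] if norm > 0 else list(row))
    return normalized


rows = [[3.0, 4.0], [1.0, 0.0], [0.0, 0.0]]
unit_rows = l2_normalize_rows(rows)
print(unit_rows)
```

Because the scaling is applied per row, the relative proportions of the features within a sample are preserved, while the overall magnitude of the sample is discarded.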
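Building the training and test sets as tensors
The shuffled data is converted to tensors, wrapped with Data.TensorDataset, and served in batches by Data.DataLoader.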
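```python
import torch
import torch.utils.data as Data

dataMat, labelMat = loadDateSet('wine.data')
dataMat = LL(dataMat, labelMat)

# Shuffle samples and labels with the same random permutation
pindex = np.random.permutation(dataMat.shape[0])
dataMat = dataMat[pindex, :]
labelMat = labelMat[pindex]

dataMat = torch.from_numpy(dataMat)
labelMat = torch.from_numpy(labelMat)

# Training set: samples from index 28 onward (150 samples)
torch_dataset = Data.TensorDataset(dataMat[28:], labelMat[28:])
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=15,
    shuffle=True,
    num_workers=2
)

# Test set: the first 28 samples, disjoint from the training set.
# batch_size=29 exceeds the test-set size, so it is returned as one batch.
torch_testset = Data.TensorDataset(dataMat[:28], labelMat[:28])
loader2 = Data.DataLoader(
    dataset=torch_testset,
    batch_size=29,
    shuffle=True,
    num_workers=2
)
```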
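The bookkeeping behind this split can be checked with a small dependency-free sketch. It assumes the first 28 shuffled samples are held out for testing and the remaining 150 are trained on in batches of 15; the variable names and the fixed seed are illustrative only.

```python
import random

n_samples = 178   # size of the wine dataset
n_test = 28       # assumed hold-out size
batch_size = 15   # training batch size

rng = random.Random(0)  # fixed seed, only to make the sketch reproducible
indices = list(range(n_samples))
rng.shuffle(indices)

test_idx = indices[:n_test]
train_idx = indices[n_test:]

# Partition the training indices into consecutive batches of 15
batches = [train_idx[i:i + batch_size]
           for i in range(0, len(train_idx), batch_size)]
print(len(train_idx), len(test_idx), len(batches))  # prints: 150 28 10
```

With 150 training samples and a batch size of 15, every epoch consists of 10 full batches, so no partial batch is left over.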
Building the MLP
class MLP(torch.nn.Module):
def __init__