1. 引入模块,读取数据
2. 构建计算图(构建网络模型)
3. 损失函数与优化器
4. 开始训练模型
5. 对训练的模型预测结果进行评估
数据采用糖尿病分类数据集diabetes.csv。这是一个典型的分类问题数据,包含768个样本,每个样本的数据包含8个特征,分别代表受试者的不同身体指标,标签为0或1,代表是否患有糖尿病。数据集示意图如下:
import torch
import torch.nn.functional as F
import torch.nn.init as init
import math
import numpy as np
import pandas as pd
from torch.autograd import Variable
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
#从本地读取数据
xy = pd.read_csv('./diabetes.csv').values
x = Variable(torch.from_numpy(xy[:,0:-1]))
y = Variable(torch.from_numpy(xy[:,-1]))
#划分训练数据和测试数据
x_train, x_test,y_train,y_test= train_test_split(x.numpy