import torch
import numpy as np
#这是一个抽象类,抽象类不能被实例化,只能被其他子类继承
from torch.utils.data import Dataset
#用来帮助我们加载数据的,可以用来实例化
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
#自定义类继承dataset
class DiabetesDataset(Dataset):
#init有两种选择1把所有数据都加载进来(数据集本身数据不大时),
# 2当数据文件过大时,将文件名放入列表,再使用getitem()通过索引根据文件名读入数据
def __init__(self,filepath):
xy=np.loadtxt(filepath,delimiter=',',dtype=np.float32)
# xy是一个N行9列的数据集(N是数据样本的数量,8是特征列,1是目标列 (N,9)的元组
#shape是(N,9)的元组,第0个元素就是N,第一个元素就是9
self.len=xy.shape[0]
self.x_data=torch.from_numpy(xy[:,:-1])
self.y_data=torch.from_numpy(xy[:,[-1]])
# 使对象能够通过索引进行下标操作
def __getitem__(self, index):
return self.x_data[index],self.y_data[index]
# 这两个都是魔法函数 有__
#len(xx)能够返回数据集的数据条数
def __len__(self):
return self.len
dataset=DiabetesDataset('diabetes.csv.gz')
#需要用loader进行迭代的代码要封装起来
train_loader=DataLoader(dataset=dataset,
batch_size=32,#一个小批量的容量是多少
shuffle=True,
num_workers=2)#2个并行进程读取数据
class Model(torch.nn.Module):
def __init__(self):
super(Model,self).__init__()
self.linear1=torch.nn.Linear(8,6)
self.linear2 = torch.nn.Linear(6, 4)
self.linear3 = torch.nn.Linear(4, 1)
self.sigmod=torch.nn.Sigmoid()
def forward(self,x):
x=self.sigmod(self.linear1(x))
x=self.sigmod(self.linear2(x))
x=self.sigmod(self.linear3(x))
return x
model=Model()
#构造损失和优化器
criterion = torch.nn.BCELoss(reduction='mean')
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)
epoch_list = []
loss_list = []
#epoch batchsize Iterations
'''
如果使用minibatch要使用嵌套循环,最外层循环是epoch
在每一次epoch里面再执行一次循环,这次循环每次迭代执行一次minibatch
epoch:所有训练样本都进行一次前向传播和反向传播
batchsize:每次训练的时候(一次前馈一次反馈)所用的样本数量
Iterations:迭代数量,是根据batchsize来决定的
shuffle:指的是对数据集进行shuffle,把顺序打乱,提高数据样本的随机性
'''
#由于windows与Linux处理多线程的库不同,需要在双重for外加上 if __name__ == '__main__' 防止报错
if __name__=='__main__':
for epoch in range(100):
#train_loader得到的元组(x,y)放在data里面
for i,data in enumerate(train_loader,0):#i是迭代次数
#for i, (x,y) in enumerate(train_loader, 0):
#1 prepare data
#训练之前先把x,y拿出来,dataloader会自动把xy转换成tensor
inputs,labels=data#两个都是张量
#2 forward
y_pred=model(inputs)
loss=criterion(y_pred,labels)
print(epoch,loss.item())
epoch_list.append(epoch)
loss_list.append(loss.item())
#3 backward
optimizer.zero_grad()
loss.backward()
#4 update
optimizer.step()
plt.plot(epoch_list, loss_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()
【Pytorch深度学习实践】刘二大人8加载数据集
最新推荐文章于 2024-05-23 08:07:18 发布