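```python
import numpy as np
import torch
from torch.utils.data import Dataset     # Dataset is an abstract base class: subclass it rather than using it directly
from torch.utils.data import DataLoader  # DataLoader is instantiated directly


class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        # np.loadtxt decompresses .gz files transparently
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]                       # number of samples (rows)
        self.x_data = torch.from_numpy(xy[:, :-1])   # all columns except the last: features
        self.y_data = torch.from_numpy(xy[:, [-1]])  # last column kept as an (N, 1) matrix: labels

    def __getitem__(self, index):
        # Return one (features, label) sample; DataLoader collates samples into batches
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        # Return the number of samples in the dataset
        return self.len


dataset = DiabetesDataset('diabetes.csv.gz')
# num_workers: number of worker subprocesses that load batches in parallel
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()
        self.activate = torch.nn.ReLU()

    def forward(self, x):
        x = self.activate(self.linear1(x))
        x = self.activate(self.linear2(x))
        # The output layer uses Sigmoid, not ReLU: ReLU outputs 0 for every x < 0,
        # and computing the loss could then require ln(0).
        x = self.sigmoid(self.linear3(x))
        return x


model = Model()
# size_average=True is deprecated; reduction='mean' is the current equivalent
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# enumerate is typically used in a for loop to get a counter alongside each item:
# it yields (index, value) pairs, so use it whenever you need both.

if __name__ == '__main__':
    # The guard matters because num_workers > 0 starts worker processes, which re-import
    # this module under the spawn start method (the default on Windows and macOS).
    for epoch in range(100):
        for i, data in enumerate(train_loader, 0):  # 0: start counting from index 0 (the default)
            # 1. Prepare data
            inputs, labels = data
            # 2. Forward
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
            # 3. Backward
            optimizer.zero_grad()
            loss.backward()
            # 4. Update
            optimizer.step()
```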
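The `Dataset`/`DataLoader` contract can be checked without the original `diabetes.csv.gz`. Below is a minimal sketch that writes a small synthetic gzip-compressed CSV (10 rows of 8 random features plus a random 0/1 label) and loads it the same way `DiabetesDataset` does. The file name `toy.csv.gz` and the class name `CsvDataset` are hypothetical. The sketch shows that `xy[:, [-1]]` keeps the labels as an (N, 1) column, which matches the model's (N, 1) output, and that with `batch_size=4` the 10 samples are split into batches of 4, 4 and 2.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class CsvDataset(Dataset):
    """Hypothetical stand-in for DiabetesDataset: the last CSV column is the label."""

    def __init__(self, filepath):
        # np.loadtxt decompresses .gz files transparently
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.x_data = torch.from_numpy(xy[:, :-1])   # features: shape (N, 8) for this toy file
        self.y_data = torch.from_numpy(xy[:, [-1]])  # [-1] inside a list keeps a 2-D (N, 1) shape

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.x_data.shape[0]


# Synthetic data only: 10 rows of 8 random features followed by a random 0/1 label
rng = np.random.default_rng(0)
rows = np.hstack([rng.random((10, 8)), rng.integers(0, 2, size=(10, 1))])
path = os.path.join(tempfile.mkdtemp(), 'toy.csv.gz')  # hypothetical file name
np.savetxt(path, rows, delimiter=',', fmt='%.6f')      # a .gz suffix makes savetxt gzip the file

dataset = CsvDataset(path)
print(len(dataset), dataset.x_data.shape, dataset.y_data.shape)
# 10 torch.Size([10, 8]) torch.Size([10, 1])

# num_workers defaults to 0, so batches are loaded in the main process here
loader = DataLoader(dataset, batch_size=4, shuffle=True)
batch_sizes = []
for i, (inputs, labels) in enumerate(loader):
    batch_sizes.append(inputs.shape[0])
    print(i, tuple(inputs.shape), tuple(labels.shape))
print(batch_sizes)  # [4, 4, 2]
```

The last batch keeps the 2 leftover samples, so `len(loader)` is ceil(10 / 4) = 3. Passing `drop_last=True` to `DataLoader` would discard that short batch instead.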
pytorch-08. Loading a Dataset
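This post shows how to build a prediction model for a diabetes dataset with PyTorch. It first defines a DiabetesDataset class that inherits from torch.utils.data.Dataset and loads and preprocesses the diabetes data. It then builds a simple neural network from several linear layers and activation functions. Finally, it defines the training loop, which loads data with a DataLoader and updates the model parameters with the SGD optimizer.
(Summary generated automatically by CSDN.)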
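The comment in the model's `forward` method says the output layer uses Sigmoid rather than ReLU because ReLU returns 0 for negative inputs, and the loss would then involve ln(0). The small sketch below makes that concrete. The three raw output values (`logits`) are hypothetical. The sketch relies on documented PyTorch behavior: `BCELoss` clamps its log outputs to be at least -100, so a prediction of exactly 0 for a label of 1 yields a loss of 100 rather than infinity. With Sigmoid, these predictions stay strictly between 0 and 1, and the loss reduces to the ordinary -mean(ln p) over samples labeled 1.

```python
import torch

criterion = torch.nn.BCELoss(reduction='mean')

# Hypothetical raw outputs of the last Linear layer for three samples, all labeled 1
logits = torch.tensor([[-3.0], [0.0], [2.0]])
labels = torch.ones(3, 1)

relu_pred = torch.relu(logits)        # [[0.], [0.], [2.]]: zero for x <= 0, and 2.0 is not a probability
sigmoid_pred = torch.sigmoid(logits)  # every value here lies strictly between 0 and 1

# Prediction 0 for label 1 is the ln(0) case; BCELoss clamps log outputs at -100,
# so the loss comes out as 100 instead of inf
zero_case_loss = criterion(torch.zeros(1, 1), torch.ones(1, 1))
print(zero_case_loss.item())  # 100.0

# With Sigmoid outputs the loss is the ordinary -mean(ln p) for label-1 samples
sigmoid_loss = criterion(sigmoid_pred, labels)
print(sigmoid_loss.item())
```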