import torch
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset#抽象类,不能实例化,只能被继承
from torch.utils.data import DataLoader#可以实例化,加载数据
class DiabetesDataset(Dataset):  # concrete subclass of the abstract torch Dataset
    """Map-style dataset loaded entirely into memory from a comma-separated file.

    Every column except the last is a feature; the last column is the target.
    """

    def __init__(self, filepath):
        """Read ``filepath`` as float32 and split it into features and targets.

        Args:
            filepath: path of a comma-separated numeric file. ``np.loadtxt``
                transparently decompresses ``.gz`` files.
        """
        # ndmin=2 keeps the table 2-D even when the file holds a single row;
        # by default np.loadtxt squeezes that case to 1-D and the column
        # slicing below would raise IndexError.
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32, ndmin=2)
        self.len = xy.shape[0]  # number of samples (rows)
        # Features: all columns but the last -> shape (N, C - 1).
        self.x_data = torch.from_numpy(xy[:, :-1])
        # Target: indexing with the list [-1] keeps a trailing axis -> (N, 1),
        # so each label has the same shape as a single-unit model output.
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        """Return the ``(features, target)`` pair for sample ``index``."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.len
# Raw string literal: the Windows path contains backslash sequences ("\P",
# "\d") that are not valid escapes and trigger a Deprecation/SyntaxWarning in a
# normal literal. The value is unchanged. Note the path is machine-specific.
dataset = DiabetesDataset(r'E:\PyTorch深度学习实践\diabetes.csv.gz')
class Model(torch.nn.Module):
    """Fully connected binary classifier with layer widths 8 -> 6 -> 4 -> 1.

    A sigmoid follows every layer, including the last, so outputs are in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Keep this creation order: each Linear draws its initial weights from
        # the global RNG, and the attribute names become state_dict keys.
        self.Linear1 = torch.nn.Linear(8, 6)
        self.Linear2 = torch.nn.Linear(6, 4)
        self.Linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Map ``(..., 8)`` inputs to ``(..., 1)`` sigmoid outputs."""
        hidden = x
        for layer in (self.Linear1, self.Linear2, self.Linear3):
            hidden = self.sigmoid(layer(hidden))
        return hidden
model = Model()
# reduction='mean' replaces the deprecated size_average=True with the same
# behavior: the binary cross-entropy is averaged over all elements.
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# shuffle=True reshuffles the samples every epoch. num_workers=2 loads batches
# in two worker processes. Where workers are spawned (e.g. Windows), each one
# re-imports this module, so this module-level setup runs there too.
train_loader = DataLoader(dataset=dataset,
                          batch_size=32, shuffle=True, num_workers=2)
# The DataLoader uses worker processes (num_workers > 0). On platforms that
# spawn them (e.g. Windows) each worker re-imports this module, so the training
# loop must sit behind the __main__ guard instead of running at top level.
if __name__ == '__main__':
    for epoch in range(100):
        for batch_idx, (features, targets) in enumerate(train_loader):
            # Forward pass: model outputs and the loss for this batch.
            predictions = model(features)
            loss = criterion(predictions, targets)
            print(epoch, batch_idx, loss.item())
            # Backward pass: clear stale gradients, then backpropagate.
            optimizer.zero_grad()
            loss.backward()
            # Apply the parameter update.
            optimizer.step()
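# 加载数据集-pytorch (Loading a dataset with PyTorch)
# 最新推荐文章于 2024-05-18 16:39:33 发布 (source-page note: latest recommended article published 2024-05-18 16:39:33)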