Loading a Custom Diabetes Dataset in PyTorch with DataLoader


Preface

Continuing from the previous post, this article again uses the diabetes patient classification case to build a neural network with PyTorch[1]. The main topic is how to write a custom dataset class and load it with DataLoader; building on that, a small extension[2] splits the data into a training set and a test set and adds accuracy (acc) to the model's evaluation metrics. [This is the sixth article in the deep learning mathematical principles series.]
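Before the full script, here is a minimal sketch of the protocol a custom dataset must follow, using a made-up in-memory toy dataset rather than the diabetes data: implement __init__, __getitem__, and __len__, then let DataLoader handle the shuffling and batching.

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):                       # hypothetical example, not the diabetes data
    def __init__(self):
        self.x = torch.randn(10, 3)              # 10 samples with 3 features each
        self.y = torch.randint(0, 2, (10, 1)).float()   # random 0/1 labels

    def __getitem__(self, index):                # return one (sample, label) pair by index
        return self.x[index], self.y[index]

    def __len__(self):                           # total number of samples
        return self.x.shape[0]

loader = DataLoader(ToyDataset(), batch_size=4, shuffle=True)
for xb, yb in loader:                            # DataLoader yields mini-batches
    print(xb.shape, yb.shape)                    # e.g. torch.Size([4, 3]) torch.Size([4, 1])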


PyTorch Implementation of This Case

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 2021/12/1 22:50
# @Author  : William Baker
# @FileName: lesson8_dataset_batch.py
# @Software: PyCharm
# @Blog    : https://blog.csdn.net/weixin_43051346

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'   # work around duplicate OpenMP runtime errors on some machines
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        # np.loadtxt reads the comma-separated file and transparently decompresses .gz
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]                       # number of samples
        self.x_data = torch.from_numpy(xy[:, :-1])   # all columns except the last: 8 features
        self.y_data = torch.from_numpy(xy[:, [-1]])  # last column, kept 2-D: the 0/1 label

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

dataset = DiabetesDataset('./dataset/diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, batch_size=32,
                          shuffle=True, num_workers=2)

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()       # treat sigmoid as a layer of the network, not just a plain function call

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))       # y_hat
        return x

model = Model()

criterion = torch.nn.BCELoss(reduction='mean')    # binary cross-entropy, averaged over the mini-batch
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

if __name__ == '__main__':      # required on Windows/macOS: num_workers > 0 spawns processes that re-import this module
    epoch_list = []
    loss_list = []
    for epoch in range(1000):
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print('Epoch:{}, iter:{}, Loss:{}'.format(epoch, i, loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_list.append(epoch)
        loss_list.append(loss.item())   # record the loss of the last mini-batch in this epoch

    plt.plot(epoch_list, loss_list)
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.show()

The output is:

Epoch:0, iter:0, Loss:0.5866138935089111
Epoch:0, iter:1, Loss:0.7563868165016174
Epoch:0, iter:2, Loss:0.568118691444397
Epoch:0, iter:3, Loss:0.6052914261817932
Epoch:0, iter:4, Loss:0.6238793134689331
Epoch:0, iter:5, Loss:0.6244214177131653
Epoch:0, iter:6, Loss:0.6615402698516846
Epoch:0, iter:7, Loss:0.6803194284439087
Epoch:0, iter:8, Loss:0.6616236567497253
Epoch:0, iter:9, Loss:0.6055610179901123
Epoch:0, iter:10, Loss:0.6427847743034363
Epoch:0, iter:11, Loss:0.6243879199028015
Epoch:0, iter:12, Loss:0.6052394509315491
Epoch:0, iter:13, Loss:0.6425594091415405
Epoch:0, iter:14, Loss:0.6997544765472412
Epoch:0, iter:15, Loss:0.6616796851158142
Epoch:0, iter:16, Loss:0.6804618835449219
Epoch:0, iter:17, Loss:0.6429725885391235
Epoch:0, iter:18, Loss:0.6803658604621887
Epoch:0, iter:19, Loss:0.7556402087211609
Epoch:0, iter:20, Loss:0.5494366884231567
Epoch:0, iter:21, Loss:0.5864083766937256
Epoch:0, iter:22, Loss:0.6809739470481873
Epoch:0, iter:23, Loss:0.6454336047172546
Epoch:1, iter:0, Loss:0.6619009971618652
Epoch:1, iter:1, Loss:0.6054559350013733
Epoch:1, iter:2, Loss:0.5673891305923462
Epoch:1, iter:3, Loss:0.5676637887954712
Epoch:1, iter:4, Loss:0.6054767370223999
Epoch:1, iter:5, Loss:0.6617451906204224
Epoch:1, iter:6, Loss:0.7186353206634521
Epoch:1, iter:7, Loss:0.6050671339035034
Epoch:1, iter:8, Loss:0.6805665493011475
Epoch:1, iter:9, Loss:0.6429550051689148
Epoch:1, iter:10, Loss:0.6995329260826111
Epoch:1, iter:11, Loss:0.5868749022483826
Epoch:1, iter:12, Loss:0.6609913110733032
Epoch:1, iter:13, Loss:0.736927330493927
Epoch:1, iter:14, Loss:0.6614935398101807
Epoch:1, iter:15, Loss:0.7184151411056519
Epoch:1, iter:16, Loss:0.7373064160346985
Epoch:1, iter:17, Loss:0.6053014993667603
Epoch:1, iter:18, Loss:0.6615592241287231
Epoch:1, iter:19, Loss:0.5683175921440125
Epoch:1, iter:20, Loss:0.6054319739341736
Epoch:1, iter:21, Loss:0.6805724501609802
Epoch:1, iter:22, Loss:0.6430844068527222
Epoch:1, iter:23, Loss:0.5671136379241943
Epoch:2, iter:0, Loss:0.6806985139846802
Epoch:2, iter:1, Loss:0.6431689262390137
Epoch:2, iter:2, Loss:0.6054792404174805
Epoch:2, iter:3, Loss:0.6427459120750427
Epoch:2, iter:4, Loss:0.7186117768287659
Epoch:2, iter:5, Loss:0.6427768468856812
Epoch:2, iter:6, Loss:0.6996884346008301
Epoch:2, iter:7, Loss:0.7186328768730164
Epoch:2, iter:8, Loss:0.5862993001937866
Epoch:2, iter:9, Loss:0.6806195378303528
Epoch:2, iter:10, Loss:0.6244773268699646
Epoch:2, iter:11, Loss:0.661752462387085
Epoch:2, iter:12, Loss:0.5865373611450195
Epoch:2, iter:13, Loss:0.6242547035217285
Epoch:2, iter:14, Loss:0.643303632736206
Epoch:2, iter:15, Loss:0.6615922451019287
Epoch:2, iter:16, Loss:0.6052978038787842
Epoch:2, iter:17, Loss:0.7180856466293335
Epoch:2, iter:18, Loss:0.623708963394165
...
Epoch:998, iter:21, Loss:0.7610572576522827
Epoch:998, iter:22, Loss:0.6822502017021179
Epoch:998, iter:23, Loss:0.7005165219306946
Epoch:999, iter:0, Loss:0.6421785354614258
Epoch:999, iter:1, Loss:0.7613226175308228
Epoch:999, iter:2, Loss:0.6625707745552063
Epoch:999, iter:3, Loss:0.6627084612846375
Epoch:999, iter:4, Loss:0.7413778305053711
Epoch:999, iter:5, Loss:0.6616511344909668
Epoch:999, iter:6, Loss:0.6027993559837341
Epoch:999, iter:7, Loss:0.5633026957511902
Epoch:999, iter:8, Loss:0.6229657530784607
Epoch:999, iter:9, Loss:0.6030195951461792
Epoch:999, iter:10, Loss:0.6025324463844299
Epoch:999, iter:11, Loss:0.622697114944458
Epoch:999, iter:12, Loss:0.8208112120628357
Epoch:999, iter:13, Loss:0.6619958281517029
Epoch:999, iter:14, Loss:0.6227477788925171
Epoch:999, iter:15, Loss:0.6026306748390198
Epoch:999, iter:16, Loss:0.6423561573028564
Epoch:999, iter:17, Loss:0.5632070899009705
Epoch:999, iter:18, Loss:0.6031973958015442
Epoch:999, iter:19, Loss:0.6428719758987427
Epoch:999, iter:20, Loss:0.6228582262992859
Epoch:999, iter:21, Loss:0.7022585868835449
Epoch:999, iter:22, Loss:0.6028652191162109
Epoch:999, iter:23, Loss:0.6170942187309265

[Figure: training loss (last mini-batch of each epoch) over 1000 epochs]
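The log above shows 24 mini-batches per epoch (iter 0 through 23), which is simply len(dataset) divided by the batch size of 32, rounded up. As a quick sanity check before training, the following sketch (assuming the same DiabetesDataset class and file path from the script above) pulls a single mini-batch and prints its shapes:

from torch.utils.data import DataLoader

check_set = DiabetesDataset('./dataset/diabetes.csv.gz')
check_loader = DataLoader(dataset=check_set, batch_size=32,
                          shuffle=False, num_workers=0)   # num_workers=0: safe to run at module level
inputs, labels = next(iter(check_loader))   # pull one mini-batch
print(len(check_set))                       # total number of samples
print(inputs.shape)                         # torch.Size([32, 8]): 32 samples x 8 features
print(labels.shape)                         # torch.Size([32, 1]): one 0/1 label per sample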


Extension

Split the data into a training set and a test set, and add accuracy (acc) to the model's evaluation metrics.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 2021/12/2 6:01
# @Author  : William Baker
# @FileName: lesson8_train_val.py
# @Software: PyCharm
# @Blog    : https://blog.csdn.net/weixin_43051346
# import os
# os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split


raw_data = np.loadtxt('C:/Users/***/dataset/diabetes.csv.gz', delimiter=',', dtype=np.float32)
X = raw_data[:, :-1]     # 8 feature columns
Y = raw_data[:, [-1]]    # label column, kept 2-D
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, test_size=0.3)   # 70/30 split
Xtest = torch.from_numpy(Xtest)    # the test set is evaluated as whole tensors, so no DataLoader is needed
Ytest = torch.from_numpy(Ytest)

class DiabetesDataset(Dataset):
    def __init__(self, data, label):
        self.len = data.shape[0]                # number of training samples
        self.x_data = torch.from_numpy(data)    # features (numpy array -> tensor)
        self.y_data = torch.from_numpy(label)   # labels (numpy array -> tensor)

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

dataset = DiabetesDataset(Xtrain, Ytrain)
train_loader = DataLoader(dataset=dataset, batch_size=32,
                          shuffle=True, num_workers=1)

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()       # treat sigmoid as a layer of the network, not just a plain function call

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))       # y_hat
        return x

model = Model()

criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

def train(epoch):
    train_loss = 0.0
    count = 0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        y_pred = model(inputs)
        loss = criterion(y_pred, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        count = i + 1                       # number of mini-batches processed (i is 0-based)

    if epoch % 20 == 19:
        print("train loss:", train_loss / count, end=',')

def test():
    with torch.no_grad():       # no gradients are needed during evaluation
        y_pred = model(Xtest)
        # threshold the predicted probabilities at 0.5 to get hard 0/1 labels
        y_pred_label = torch.where(y_pred >= 0.5, torch.tensor([1.0]), torch.tensor([0.0]))
        acc = torch.eq(y_pred_label, Ytest).sum().item() / Ytest.size(0)
        print("test acc:", acc)

if __name__ == '__main__':
    for epoch in range(500):
        train(epoch)
        if epoch % 20 == 19:    # evaluate on the test set every 20 epochs
            test()

The output is:

train loss: 0.6836073733866215,test acc: 0.6403508771929824
train loss: 0.6811290085315704,test acc: 0.6403508771929824
train loss: 0.683714009821415,test acc: 0.6403508771929824
train loss: 0.6818990334868431,test acc: 0.6403508771929824
train loss: 0.6818503625690937,test acc: 0.6403508771929824
train loss: 0.6826709099113941,test acc: 0.6403508771929824
train loss: 0.6799936965107918,test acc: 0.6403508771929824
train loss: 0.6799548231065273,test acc: 0.6403508771929824
train loss: 0.6781436987221241,test acc: 0.6403508771929824
train loss: 0.6833714731037617,test acc: 0.6403508771929824
train loss: 0.6814881190657616,test acc: 0.6403508771929824
train loss: 0.6823583282530308,test acc: 0.6403508771929824
train loss: 0.6779079474508762,test acc: 0.6403508771929824
train loss: 0.6813166737556458,test acc: 0.6403508771929824
train loss: 0.6830350793898106,test acc: 0.6403508771929824
train loss: 0.6811831593513489,test acc: 0.6403508771929824
train loss: 0.6802380681037903,test acc: 0.6403508771929824
train loss: 0.6784024983644485,test acc: 0.6403508771929824
train loss: 0.6809254698455334,test acc: 0.6403508771929824
train loss: 0.6799669452011585,test acc: 0.6403508771929824
train loss: 0.6808101087808609,test acc: 0.6403508771929824
train loss: 0.6806606911122799,test acc: 0.6403508771929824
train loss: 0.6805795542895794,test acc: 0.6403508771929824
train loss: 0.6778338365256786,test acc: 0.6403508771929824
train loss: 0.676879420876503,test acc: 0.6403508771929824
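
As an aside, the 70/30 split can also be done entirely inside PyTorch with torch.utils.data.random_split, which removes the sklearn dependency. A minimal sketch, assuming the file-loading DiabetesDataset class from the first script:

import torch
from torch.utils.data import DataLoader, random_split

full_set = DiabetesDataset('./dataset/diabetes.csv.gz')   # the file-loading version from the first script
n_test = int(0.3 * len(full_set))                         # same 70/30 ratio as train_test_split above
n_train = len(full_set) - n_test
train_set, test_set = random_split(full_set, [n_train, n_test],
                                   generator=torch.Generator().manual_seed(42))  # fixed seed: reproducible split

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
test_loader = DataLoader(test_set, batch_size=len(test_set))    # evaluate the test set in one batch

The fixed generator plays the same role that random_state would play in train_test_split: without it, every run draws a different split, and accuracies are not comparable across runs.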

That just about brings this article to a close. If you spot any mistakes, corrections are very welcome. And if this post helped you, that makes me happy too; how far a person can go depends on who they walk with.


References


  1. 《PyTorch深度学习实践》完结合集 - 08. 加载数据集 (PyTorch Deep Learning Practice, complete series, Lecture 08: Loading the Dataset) ↩︎

  2. PyTorch 深度学习实践 第8讲 (PyTorch Deep Learning Practice, Lecture 8) ↩︎
