第八讲加载数据集

第八讲加载数据集的笔记和源码

B站 刘二大人 教学视频 传送门:加载数据集

简单加载知识:

1、首先Dataset是个抽象的函数,它不能直接进行实例化,所以要创建一个自己的类,继承Dataset。

2、__init__() 是初始化函数,之后我们可以提供数据集路径进行数据的加载

3、__getitem()__() 用于帮助我们通过索引来找到某个样本

batch_size 每一组有多少个样本,shuffle=True,意思是数据打乱

 

4、__len__() 帮助我们返回数据集的大小

5、我补充一个知识点:在前面BCELoss中 reduction 和 size_average的用法:

       reduction的参数 mean 表示损失均值  sum表示损失求和

        size_aveage的参数True 表示返回损失均值 False表示返回损失求和


本讲源码:

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

# Dataset是一个抽象函数,不能直接实例化,所以我们要创建一个自己类,继承Dataset
# 继承Dataset后我们必须实现三个函数:
# __init__()是初始化函数,之后我们可以提供数据集路径进行数据的加载
# __getitem__()帮助我们通过索引找到某个样本
# __len__()帮助我们返回数据集大小

#糖尿病数据集分类
class DiabetesDataset(Dataset):
    def __init__(self,filepath):
        xy=np.loadtxt(filepath,delimiter=',',dtype=np.float32)
        self.len=xy.shape[0] #矩阵行数,[1]为列数
        self.x_data=torch.from_numpy(xy[:,:-1])
        self.y_data=torch.from_numpy(xy[:,[-1]])

    def __getitem__(self, index):
        return self.x_data[index],self.y_data[index]

    def __len__(self):
        return self.len

dataset=DiabetesDataset('diabetes.csv.gz')
#DataLoader, dataset数据 batch_size每次训练多少即每组有多少个样本 
#Shuffle是否打乱 num_works几个并行进程
train_loader=DataLoader(dataset=dataset,batch_size=32,shuffle=True,num_workers=0)

class Model(torch.nn.Module):
    def __init__(self):
        super(Model , self).__init__()
        #输入8维 输出1维
        self.linear1=torch.nn.Linear(8,6)
        self.linear2=torch.nn.Linear(6,4)
        self.linear3=torch.nn.Linear(4,1)
        self.sigmoid=torch.nn.Sigmoid()
    def forward(self,x):
        x=self.sigmoid(self.linear1(x))
        x=self.sigmoid(self.linear2(x))
        x=self.sigmoid(self.linear3(x))
        return x
model=Model()

criterion=torch.nn.BCELoss(reduction='mean')
optimizer=optim.SGD(model.parameters(),lr=0.01)


for epoch in range(100):
    for i,data in enumerate(train_loader,0): #取出一个bath
        #1.prepare data
        inputs,labels=data  #将输入的数据赋给inputs,结果赋给labels
        #2.Forward
        y_pred=model(inputs)
        loss=criterion(y_pred,labels)
        print(epoch,i,loss.item())
        #3.Backward
        optimizer.zero_grad()
        loss.backward()
        #4.update
        optimizer.step()

结果部分展示:

98 22 0.6628814935684204
98 23 0.6179142594337463
99 0 0.6623567342758179
99 1 0.702293872833252
99 2 0.6040289998054504
99 3 0.6433210372924805
99 4 0.524371862411499
99 5 0.6624910831451416
99 6 0.6632858514785767
99 7 0.6229968070983887
99 8 0.6228879690170288
99 9 0.5628733038902283
99 10 0.5231003761291504
99 11 0.7231881618499756
99 12 0.6430286765098572
99 13 0.6830095052719116
99 14 0.6229137182235718
99 15 0.683765709400177
99 16 0.6428204774856567
99 17 0.6830291152000427
99 18 0.6825512051582336
99 19 0.6831673979759216
99 20 0.7024272084236145
99 21 0.6226396560668945
99 22 0.643019437789917
99 23 0.6727415323257446


进程已结束,退出代码为 0


完结,如有任何错误,敬请指正,非常感谢!

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值