Pytorch学习笔记(十 一)——pytorch自定义数据集

一、为什么要使用Datasets类

  Datasets是pytorch的一个类,pytorch自带多种数据集,如:MINIST等数据集就是在pytorch的Datasets的库中的。
  Pytorch中有工具函数torch.utils.Data.DataLoader,通过这个函数我们在准备加载数据集使用mini-batch的时候可以使用多线程并行处理,这样可以加快我们准备数据集的速度。Datasets就是构建这个工具函数的实例参数之一。

二、如何定义Datasets?

Dataset类是Pytorch中最为重要的一个类,也是Pytorch中所有数据集加载类中应该继承的父类。其中父类中的两个私有成员函数必须被重载,否则将会触发错误提示:

def getitem(self, index):
def len(self):

其中__len__应该返回数据集的大小,而__getitem__应该编写支持数据集索引的函数
这里重点看 getitem函数,getitem接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。

三、实战

数据集的内容组成
在这里插入图片描述

import torch
import torch.nn as nn
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from torchvision import transforms

class Model(nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.layer1=nn.Sequential(
            nn.Linear(3,20),
            nn.Sigmoid(),
            nn.Linear(20,40),
            nn.Sigmoid(),
            nn.Linear(40,1)
        )
    def forward(self,x):
        data=x
        data=self.layer1(data)
        return data

class MyDataset(Dataset):
    def __init__(self,root,transform=None):
        super(MyDataset,self).__init__()
        #读取数据,整理读取的x值为一列
        df=pd.read_csv(root,dtype=np.float32)
        #self.data=pd.DataFrame(columns=['data','label'])
        data=[] #用于获取3个x值并组合为一列
        label=[] #用于获取标签值
        self.data=[]
        self.label=[]
        for i in range(df.shape[0]):
            x=df.loc[i] #type:Series
            data.append([x['x1'],x['x2'],x['x3']])
            label.append(x['y'])
        #self.data['data']=data
        #self.data['label']=label
        self.data=data
        self.label=label
        self.transform=transform



    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        x=self.data[item]
        label=self.label[item]
        if self.transform is not None:
            x=self.transform(x)
        return x,label

class ToTensor(object):
    def __call__(self, seq):
        #print(seq.shape)
        return torch.tensor(seq,dtype=torch.float)

if __name__=='__main__':
    path = 'C:/Users/Mr.Li\Desktop/test project/train.csv'
    set=MyDataset(path,ToTensor())
    data=torch.utils.data.DataLoader(dataset=set,batch_size=6,shuffle=True)
    model=Model()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_func = torch.nn.MSELoss()

    for epoch in range(100):
        for i,( x,label) in enumerate(data):
            y=model(x)
            z=label.view(-1,1)
            loss = loss_func(y, z)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(loss)






  • 0
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

酒与花生米

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值