Loading Your Own Dataset in PyTorch (Using MNIST in Image Format as an Example)

Preface

As a PyTorch beginner I went through many tutorials and found that, when it comes to loading data, they all rely on the datasets PyTorch already ships with, and never explain in detail how to use Dataset and DataLoader to load your own data in whatever format it happens to be in. After some experimentation I managed to get a simple classification model running that trains directly on plain image files, and this post records the process.
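For contrast, the "predefined module" route that most tutorials take looks roughly like the sketch below, which uses torchvision.datasets.MNIST to download and wrap the data automatically; the rest of this post builds the same kind of pipeline by hand from plain image files.

# A minimal sketch of the built-in route, for comparison only.
# torchvision downloads MNIST in its original binary form and wraps it in a Dataset for you.
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.ToTensor()
train_set = datasets.MNIST(root='./mnist_data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)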

The dataset is available here:
Link: https://pan.baidu.com/s/16T1IoAgOsepLqFRzjDck3g?pwd=h254  Extraction code: h254


1. Dataset Conversion

MNIST is one of the classic datasets. The files downloaded from the official site are binary IDX files rather than ordinary image files, so the first step is to convert the binary files into images.
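For reference, the IDX files start with a big-endian header: the image file stores a magic number, the image count, the row count and the column count as four 32-bit unsigned integers, and the label file stores a magic number and the label count as two. A quick sanity check of a downloaded file (a minimal sketch, assuming the file sits in the current directory) looks like this:

import struct

# Read the 16-byte big-endian header of the image file.
# For train-images.idx3-ubyte the expected output is: 2051 60000 28 28
with open('train-images.idx3-ubyte', 'rb') as f:
    magic, num_images, rows, cols = struct.unpack('>IIII', f.read(16))
print(magic, num_images, rows, cols)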

The conversion code is as follows:

# -*- coding: utf-8 -*-
import numpy as np
import struct
import os
import cv2


class DataUtils(object):
    def __init__(self, filename=None, outpath=None):
        self._filename = filename
        self._outpath = outpath

        self._tag = '>'  # big-endian byte order
        self._twoBytes = 'II'
        self._fourBytes = 'IIII'
        self._pictureBytes = '784B'
        self._labelByte = '1B'
        self._twoBytes2 = self._tag + self._twoBytes
        self._fourBytes2 = self._tag + self._fourBytes
        self._pictureBytes2 = self._tag + self._pictureBytes
        self._labelByte2 = self._tag + self._labelByte

        self._imgNums = 0
        self._LabelNums = 0

    def getImage(self):
        """
        将MNIST的二进制文件转换成像素特征数据
        """
        binfile = open(self._filename, 'rb')  # 以二进制方式打开文件
        buf = binfile.read()
        binfile.close()
        index = 0
        numMagic, self._imgNums, numRows, numCols = struct.unpack_from(self._fourBytes2, buf, index)
        index += struct.calcsize(self._fourBytes2)
        images = []
        print('image nums: %d' % self._imgNums)
        for i in range(self._imgNums):
            imgVal = struct.unpack_from(self._pictureBytes2, buf, index)
            index += struct.calcsize(self._pictureBytes2)
            imgVal = list(imgVal)
            images.append(imgVal)
        return np.array(images), self._imgNums

    def getLabel(self):
        """
        将MNIST中label二进制文件转换成对应的label数字特征
        """
        binFile = open(self._filename, 'rb')
        buf = binFile.read()
        binFile.close()
        index = 0
        magic, self._LabelNums = struct.unpack_from(self._twoBytes2, buf, index)
        index += struct.calcsize(self._twoBytes2)
        labels = []
        for x in range(self._LabelNums):
            im = struct.unpack_from(self._labelByte2, buf, index)
            index += struct.calcsize(self._labelByte2)
            labels.append(im[0])
        return np.array(labels)

    def outImg(self, arrX, arrY, imgNums):
        """
        根据生成的特征和数字标号,输出图像
        """
        output_txt = self._outpath + '/img.txt'
        output_file = open(output_txt, 'a+')

        m, n = np.shape(arrX)
        # each image is 28*28 = 784 bytes
        for i in range(imgNums):
            img = np.array(arrX[i], dtype=np.uint8)  # cast to uint8 so cv2.imwrite can save it
            img = img.reshape(28, 28)
            # print(img)
            outfile = str(i) + "_" + str(arrY[i]) + ".bmp"
            # print('saving file: %s' % outfile)

            txt_line = outfile + " " + str(arrY[i]) + '\n'
            output_file.write(txt_line)
            cv2.imwrite(self._outpath + '/' + outfile, img)
        output_file.close()


if __name__ == '__main__':
    # Paths to the binary files; change them to match your own locations
    trainfile_X = 'C:\\Users\\60058670\\Desktop\\MNIST\\train-images.idx3-ubyte'
    trainfile_y = 'C:\\Users\\60058670\\Desktop\\MNIST\\train-labels.idx1-ubyte'
    testfile_X = 'C:\\Users\\60058670\\Desktop\\MNIST\\t10k-images.idx3-ubyte'
    testfile_y = 'C:\\Users\\60058670\\Desktop\\MNIST\\t10k-labels.idx1-ubyte'

    # Load the MNIST data
    train_X, train_img_nums = DataUtils(filename=trainfile_X).getImage()
    train_y = DataUtils(filename=trainfile_y).getLabel()
    test_X, test_img_nums = DataUtils(testfile_X).getImage()
    test_y = DataUtils(testfile_y).getLabel()

    # Save the images as local files
    path_trainset = "C:\\Users\\60058670\\Desktop\\MNIST\\train"
    path_testset = "C:\\Users\\60058670\\Desktop\\MNIST\\test"
    if not os.path.exists(path_trainset):
        os.mkdir(path_trainset)
    if not os.path.exists(path_testset):
        os.mkdir(path_testset)
    DataUtils(outpath=path_trainset).outImg(train_X, train_y, int(train_img_nums / 10))  # /10 converts only one tenth of the data, for quick experiments
    DataUtils(outpath=path_testset).outImg(test_X, test_y, int(test_img_nums / 10))
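After the script finishes, every image is written as <index>_<label>.bmp and img.txt lists each filename together with its label. A quick check that the conversion worked (a minimal sketch; the filename below is just an illustrative example from the train folder):

import cv2

# Load one converted image and confirm its size and the label encoded in the filename.
sample = 'C:\\Users\\60058670\\Desktop\\MNIST\\train\\0_5.bmp'  # hypothetical example filename
img = cv2.imread(sample, cv2.IMREAD_GRAYSCALE)
print(img.shape)                 # expected: (28, 28)
print(sample.split('_')[-1][0])  # the digit label parsed from the filename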

2. Building Your Own Dataset

The approach is to subclass Dataset, implementing __init__, __getitem__, and __len__, and then load the data with DataLoader.

1. Import the libraries

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

2. Build the MnistDataset class

# Build a custom dataset
class MnistDataset(Dataset):
    def __init__(self, transform=None, lu_jing=None):
        self.lu_jing = lu_jing
        # keep only the .bmp images; the conversion script also writes img.txt into this folder
        self.数据 = [f for f in os.listdir(self.lu_jing) if f.endswith('.bmp')]
        self.transform = transform
        self.len = len(self.数据)

    def __getitem__(self, index):
        image_index = self.数据[index]
        img_path = os.path.join(self.lu_jing, image_index)
        img = Image.open(img_path)
        if self.transform:
            img = self.transform(img)

        # filenames look like "<index>_<label>.bmp", so the character at position -5 is the digit label
        label = int(image_index[-5])
        label = self.oneHot(label)
        return img, label

    def __len__(self):
        return self.len

    # Convert the label to a one-hot encoding
    def oneHot(self, label):
        tem = np.zeros(10, dtype=np.float32)  # float32 to match the model's output dtype
        tem[label] = 1
        # Note: one-hot / probability targets for CrossEntropyLoss require PyTorch >= 1.10;
        # returning the integer class label directly also works.
        return torch.from_numpy(tem)
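To verify that the class behaves as expected, it can be wired to a DataLoader and one batch inspected (a minimal sketch; the path is the train folder produced in step one):

from torchvision import transforms

transform = transforms.Compose([transforms.ToTensor()])
dataset = MnistDataset(transform=transform, lu_jing='C:\\Users\\60058670\\Desktop\\MNIST\\train')
loader = DataLoader(dataset, batch_size=4, shuffle=True)

imgs, labels = next(iter(loader))
print(imgs.shape)    # expected: torch.Size([4, 1, 28, 28])
print(labels.shape)  # expected: torch.Size([4, 10]) -- one-hot encoded labels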

3. Build the network model

The model is kept simple; it is only meant for demonstration.

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.Conv1 = torch.nn.Conv2d(1, 10, kernel_size=(5, 5))
        self.Conv2 = torch.nn.Conv2d(10, 20, kernel_size=(5, 5))
        self.pool = torch.nn.MaxPool2d(2)
        self.fl = torch.nn.Linear(320, 10)

    def forward(self, x):
        bs = x.size(0)
        x = F.relu(self.pool(self.Conv1(x)))
        x = F.relu(self.pool(self.Conv2(x)))
        x = x.view(bs, -1)
        x = self.fl(x)
        return x
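The in_features=320 of the final linear layer follows from the shapes: a 28x28 input becomes 24x24 after Conv1 (5x5 kernel), 12x12 after pooling, 8x8 after Conv2 and 4x4 after the second pooling, with 20 channels, i.e. 20 * 4 * 4 = 320. A dummy forward pass confirms the wiring (a minimal sketch):

import torch

model = Model()
dummy = torch.randn(2, 1, 28, 28)  # a batch of two fake single-channel 28x28 images
out = model(dummy)
print(out.shape)                   # expected: torch.Size([2, 10])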

3. Complete Code

import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
from torchvision import transforms
import torch.nn.functional as F


# Build a custom dataset
class MnistDataset(Dataset):
    def __init__(self, transform=None, lu_jing=None):
        self.lu_jing = lu_jing
        # keep only the .bmp images; the conversion script also writes img.txt into this folder
        self.数据 = [f for f in os.listdir(self.lu_jing) if f.endswith('.bmp')]
        self.transform = transform
        self.len = len(self.数据)

    def __getitem__(self, index):
        image_index = self.数据[index]
        img_path = os.path.join(self.lu_jing, image_index)
        img = Image.open(img_path)
        if self.transform:
            img = self.transform(img)

        # filenames look like "<index>_<label>.bmp", so the character at position -5 is the digit label
        label = int(image_index[-5])
        label = self.oneHot(label)
        return img, label

    def __len__(self):
        return self.len

    # Convert the label to a one-hot encoding
    def oneHot(self, label):
        tem = np.zeros(10, dtype=np.float32)  # float32 to match the model's output dtype
        tem[label] = 1
        # Note: one-hot / probability targets for CrossEntropyLoss require PyTorch >= 1.10;
        # returning the integer class label directly also works.
        return torch.from_numpy(tem)


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.Conv1 = torch.nn.Conv2d(1, 10, kernel_size=(5, 5))
        self.Conv2 = torch.nn.Conv2d(10, 20, kernel_size=(5, 5))
        self.pool = torch.nn.MaxPool2d(2)
        self.fl = torch.nn.Linear(320, 10)

    def forward(self, x):
        bs = x.size(0)
        x = F.relu(self.pool(self.Conv1(x)))
        x = F.relu(self.pool(self.Conv2(x)))
        x = x.view(bs, -1)
        x = self.fl(x)
        return x


if __name__ == '__main__':
    # Path to the training set
    train_data = "C:\\Users\\60058670\\Desktop\\MNIST\\train"
    transform = transforms.Compose([transforms.ToTensor()])  # convert to tensors and scale pixel values to [0, 1]
    data = MnistDataset(transform=transform, lu_jing=train_data)
    data_loader = DataLoader(data, batch_size=200, shuffle=True)  # load the data with DataLoader
    model = Model()
    criterion = torch.nn.CrossEntropyLoss()  # cross-entropy loss
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)  # model.parameters() hands every trainable parameter to the optimizer

    for epoch in range(20):
        for i, data1 in enumerate(data_loader, 0):  # the DataLoader shuffles the data and yields mini-batches
            inputs, labels = data1
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
        if epoch % 5 == 0:
            print(epoch, loss.item())
    # Path to the test set
    test_data = 'C:\\Users\\60058670\\Desktop\\MNIST\\test'

    x_test = MnistDataset(transform=transform, lu_jing=test_data)
    x_test = DataLoader(x_test, batch_size=100, shuffle=False)  # load the test data with DataLoader
    total = 0
    correct = 0
    for i, data in enumerate(x_test, 0):  # iterate over the test set in mini-batches, without shuffling
        inputs, labels = data
        y_pred = model(inputs)
        _, labels = torch.max(labels.data, dim=1)      # recover class indices from the one-hot labels
        _, predicted = torch.max(y_pred.data, dim=1)   # predicted class = index of the largest logit
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        print('accuracy on test set: {} % '.format(100 * correct / total))  # running accuracy over the batches seen so far
    print(correct, total)
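Two small refinements that the script above leaves out, shown as a minimal sketch rather than as part of the original code: during testing it is usual to put the model in eval mode and disable gradient tracking with torch.no_grad(), and the trained weights can be kept for later with torch.save.

model.eval()                      # switch to inference mode
correct, total = 0, 0
with torch.no_grad():             # no gradients are needed for evaluation
    for inputs, labels in x_test:
        y_pred = model(inputs)
        _, labels_idx = torch.max(labels, dim=1)
        _, predicted = torch.max(y_pred, dim=1)
        total += labels_idx.size(0)
        correct += (predicted == labels_idx).sum().item()
print('accuracy on test set: {} %'.format(100 * correct / total))

torch.save(model.state_dict(), 'mnist_cnn.pt')  # hypothetical filename for the saved weights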

Summary

What you learn on paper always feels shallow; to really understand something you have to do it yourself. The moment you write the code on your own you run into a pile of problems, and that is exactly how knowledge accumulates: by solving them. I have not been learning this for long, so if you have questions or spot issues, let's discuss them together.
