利用自己数据集来训练神经网络pytorch,重写Dataset类

本文介绍如何使用PyTorch构建并训练一个针对蚂蚁和蜜蜂数据集的深度学习模型。代码展示了如何创建自定义数据集类`Mydataset`,包括图像预处理、数据加载器的设置,以及如何将数据传入模型进行训练。适合初学者参考实践。
摘要由CSDN通过智能技术生成

很多小伙伴在刚刚结束深度学习算法的时候,肯定想用自己的数据来进行训练网络,但是不知到怎么写代码,下面这个代码就会为你解惑,自己可以根据实际情况来更改代码,训练自己的图片数据集。

下面我用蚂蚁和蜜蜂数据集为例,我的数据格式是这样的,如下图:

 每个类别都会有相应的图片


from torch.utils.data import Dataset,DataLoader
import os
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
transform = transforms.Compose([transforms.Resize([500, 500]),                                          # 图像预处理
                           transforms.RandomHorizontalFlip(),
                           transforms.ToTensor(),
                           transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
                           ])
class_list={0:"ants",1:"bees"}#用于后续预测的时候可以使用,用预测到的标签来直接获取相应的类别
class Mydataset(Dataset):
    def __init__(self,file_path="D:/PycharmProjects/pythonProject/classification-pytorch-main1/datasets",formate="train",transform=False):
        self.transforms=transform
        self.file_path=file_path
        self.formate=formate
        self.file_train=os.path.join(self.file_path,self.formate)
        print(self.file_train)
        files_class = os.listdir(self.file_train)
        self.imgs=[]
        for i, j in enumerate(files_class):
            data = os.path.join(self.file_train, j)
            print(data)
            data_1 = os.listdir(data)
            data_all = [[os.path.join(data, k), i] for k in data_1]
            self.imgs += data_all
        print(self.imgs)
    def __len__(self):
        return len(self.imgs)
    def __getitem__(self, index):
        img_path, label = self.imgs[index]                                                    # 选择文件路径
        pil_img = Image.open(img_path).convert('RGB')                                         # 利用PIL打开文件路径
        if self.transforms:
            img=transform(pil_img)
        else:
            pil_img = np.asarray(pil_img)
            img = torch.from_numpy(pil_img)
        return img, label

if __name__ == '__main__':
    train_data= Mydataset(transform=True)
    print(train_data.__getitem__(0)[0])
    print(train_data.__getitem__(0)[1])
    #验证能否传进模型中
    train_dataloder=DataLoader(train_data,batch_size=8,shuffle=True)
    for data in train_dataloder:
        print(data[0].shape)
        print(data[1])
        #结果不唯一,其中的结果如下:
        # torch.Size([8, 3, 500, 500])
        # tensor([1, 0, 0, 1, 1, 0, 1, 1])
        break

如果觉得有帮助,就点个赞吧,祝大家学业有成!

  • 3
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值