pytorch学习笔记--trochvision.datasets和DataLoader的使用

本文介绍了如何在PyTorch中使用trochvision.datasets加载CIFAR10数据集,并通过DataLoader进行批量处理。示例展示了如何进行数据预处理,以及利用tensorboardX进行可视化。此外,还提供了一个自定义Dataset类`MyDate`的代码示例,用于加载和处理自定义图像数据。
摘要由CSDN通过智能技术生成

trochvision.datasets和DataLoader的使用

本文为学习笔记,感谢PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

一、datasets

datasets工具在trochvision中

import torchvision
from torchvision import transforms as tf
from tensorboardX import SummaryWriter

train_dataset = torchvision.datasets.CIFAR10(root='./dataset',transform=tf.ToTensor(),train=True,download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./dataset',transform=tf.ToTensor(),train=False,download=True)

print(train_dataset[0]) #(<PIL.Image.Image image mode=RGB size=32x32 at 0x25D73A72E48>, 6)
print(train_dataset[1])#同上,返回一张图和标签组成的元组
print(train_dataset.classes) #查看分类类型,此数据集共10类

writer = SummaryWriter('logs\\2')

#可视化十张图
for i in range(10):
	img ,label = train_dataset[i]
	writer.add_image('10train_img',img,i)
writer.close()

参数:
CIFAR10:是数据集的名字
root=’./dataset’:保存路径
transform=tf.ToTensor():对图片的转变方法
train=True:训练or测试数据
download=True:是否检测下载

二、DataLoader

from torch.utils.data import DataLoader
from torchvision import transforms as tf
import torchvision

test_dataset = torchvision.datasets.CIFAR10(root='./dataset',transform=tf.ToTensor(),train=False,download=True)

#参数batch_size是取数据集中的一个批量进行打包输出,test_iter中的每个元素都是64张图的合并
test_iter = DataLoader(dataset=test_dataset,batch_size=64,shuffle=True,num_workers=0,drop_last=True)

DataLoader中参数batch_size是取数据集中的一个批量进行打包输出,test_iter中的每个元素都是64张图的合并

参数:
dataset:读取的数据集
batch_size :批量大小
shuffle :序列的所有元素随机排序
num_worker :进程数
drop_last :是否丢弃尾部不足batch_size的数据

补充:datasets类的代码

#Dataset类的代码
from torch.utils.data import Dataset
from PIL import Image
import os
# F:\python_project\deep_learning\train\ants_image\0013035.jpg
class MyDate(Dataset):
	def __init__(self,root_dir,label_dir):
		self.root_dir = root_dir
		self.label_dir = label_dir
		self.path = os.path.join(self.root_dir,self.label_dir)
		self.img_path = os.listdir(self.path)
	def __getitem__(self, idx):
		img_name = self.img_path[idx]
		img_value_path = os.path.join(self.root_dir,self.label_dir,img_name)
		img = Image.open(img_value_path)
		label = self.label_dir
		return img,label
	def __len__(self):
		return len(self.img_path)

root_dir = 'F:\python_project\deep_learning\\train'
label_dir = 'ants_image'
ant = MyDate(root_dir,label_dir)
img,label = ant.__getitem__(1)
img.show()
print(label)
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值