1、 Dateset
- 数据的位置
# dataset 提供一种方式取获取数据及其label
# 1、如何获取每一个数据机器lable
# 2、告诉我们总共有多少数据
from torch.utils.data import Dataset
from PIL import Image #目的读取图片
import os
class MyData(Dataset):
#类的实例化,为后面提供全局变量使用
def __init__(self,root_dir,label_dir):
self.root_dir=root_dir
self.label_dir=label_dir
self.path=os.path.join(root_dir,label_dir)#label的地址,下面有很多同一个label的图片
self.img_path=os.listdir(self.path)#给label写成列表的形式即[0]为第一张图片
#返回图片的属性和所属的标签
def __getitem__(self, idx):
img_name=self.img_path[idx]
img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)#每张图片的地址
img=Image.open(img_item_path)
label=self.label_dir
return img,label
#数据集的长度
def __len__(self):
return len(self.img_path)
root_dir=r"pytorch_learning\dataset\train"
daisy_label_dir=r"daisy"
roses_label_dir=r"roses"
daisy_dataset=MyData(root_dir,daisy_label_dir)
roses_dataset=MyData(root_dir,roses_label_dir)
train_dataset=daisy_dataset + roses_dataset
2、torchvision中的数据集的使用
- 这里顺便拿一个图像分类的例子来介绍
- 数据集介绍
- 数据集使用
- 直接下载在线的训练集
- 下载的数据集+transform应用
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
])
train_set=torchvision.datasets.CIFAR10(root="D:\cv_box\pytorch_learning",transform=dataset_transform,train=True,download=True)
test_set=torchvision.datasets.CIFAR10(root="D:\cv_box\pytorch_learning",transform=dataset_transform,train=False,download=True)
writer=SummaryWriter("p10_logs")
for i in range(10):
img,target=test_set[i]
writer.add_image("test_set",img,i)
writer.close()
3.DataLoader使用
- 将数据加载我们的神经网络中(具体怎么加载?)
- 首先就是去看文档
import torchvision
from torch.utils.data import DataLoader
# 数据集准备
from torch.utils.tensorboard import SummaryWriter
dataset_transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
test_data = torchvision.datasets.CIFAR10(root="D:\cv_box\pytorch_learning",
train=True,
transform=dataset_transform,
download=True
)
# 数据集的加载,====下面的括号里面说的是,一把从test_data中取出4个进行打包,再返回imgs和target
test_loader=DataLoader(dataset=test_data,
batch_size=64,#一把抓四张图片
shuffle=False,#每一轮epoch,本次和上一次是不一样的
num_workers=0,#加载图片使用多少个线程,一般在linux下设置的,windows下默认为0
drop_last=True#设置为“True”将删除最后一个不完整的批处理
)
#测试数据集中第一张图片及target
img,target=test_data[0]
print(img)
print(target)
writer=SummaryWriter("dataloader_logs")
for epoch in range(2):
step=0
for data in test_loader:
imgs,targets=data
# print(imgs.shape)
# print(targets)
writer.add_images("droplast_epoach:{}".format(epoch),imgs,step)#这边是一坨图片
step=step+1;
writer.close()