目录
1. torchvision库
torchvision包含了目前流行的数据集,模型结构和常用的图片转换工具。
结合前边的内容学习:
Pytorch 学习第【1】天
Pytorch 学习第【2】天
Pytorch 学习第【3】天
1.1. datasets数据集
torchvision.datasets中包含了以下数据集
图像数据集、文本数据集、音频数据集
1.1.1 加载数据集
例子:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
#训练集
training_data = datasets.CIFAR10(
root="./data",
train=True,
download=True,
transform=ToTensor()
)
#测试集
test_data = datasets.CIFAR10(
root="./data",
train=False,
download=True,
transform=ToTensor()
)
【PS:若是下载太慢,可以复制给出的网址到迅雷下载,之后放在指定的目录下,再运行代码回自行解压。显示Files already downloaded and verified就ok了,查看CIFAR10数据集】
datasets参数解析
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
)
其中
root:是存储训练/测试数据的路径,
train:指定训练或测试数据集,
download=True(如果无法在root路径上获得,从互联网下载数据)
transform和target_transform并指定特征和标签转换
1.1.2 查看数据集
导入的库:
from torchvision.transforms import *
import matplotlib.pyplot as plt
from torchvision import *
随机显示部分图片并标以对应的标签:
# 将tensor 格式转化为PIL格式
to_pil_image = transforms.ToPILImage()
# 标签对应的含义
labels_map = {
0: "airplane",
1: "automobile",
2: "bird",
3: "cat",
4: "deer",
5: "dog",
6: "frog",
7: "horse",
8: "ship",
9: "truck",
}
#图片显示大小
figure = plt.figure(figsize=(8, 8))
#这里显示是三行三列
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
# 随机选择一张图片
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
# 将tensor类型转成PIL类型
img = to_pil_image(img)
#准备画图
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
#不显示坐标轴
plt.axis("off")
plt.imshow(img,)
plt.show()
【PS:需要注意的导入数据的时候是tensor类型的,在显示图片需要将tensor类型转化成PIL类型。】
当然也可以通过tensorBoard进行查看
writer = SummaryWriter("p9")
for i in range(9):
img,label = training_data[i]
writer.add_image("train_set",img,i)
writer.close()
【PS:命令忘记了可以看Pytorch 学习第【3】天】
1.1.3 输出结果
在tensorboard上:
1.2 DataLoader数据加载器
数据加载程序。结合一个数据集和一个采样器,并在给定数据集上提供一个可迭代对象。DataLoader支持地图样式和可迭代样式的数据集,支持单进程或多进程加载、自定义加载顺序、可选的自动批处理(排序)和内存固定。
参数:
CLASStorch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)
其中比较常设置的参数
dataset (dataset):用于加载数据的数据集。
Batch_size (int,可选):每批加载多少个样本(默认值:1)。
shuffle (bool,可选):设置为True,在每个epoch重新洗牌数据(认:False)。
sampler (sampler或Iterable,可选):定义了从数据集中抽取样本的策略。
Num_workers (int,可选):用于数据加载的子进程数。0表示数据将加载到主进程中。(默认值:0)
drop_last (bool,可选):设置为True表示删除最后一个不完整的批处理,如果数据集大小不能被批处理大小整除。如果为False且数据集的大小不能被批量大小整除,则最后一批数据将更小。(默认值:False)
1.2.1 tensorboard上进行查看
代码:
# 从训练集中,batch_size=64每批加载64个样本,在每个epoch重新洗牌数据,
# 加载到主进程中,保留不能整除的部分
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True,num_workers=0,drop_last=False)
writer =SummaryWriter("dataloader")
i=0
for data in test_dataloader:
imgs, targets = data
print(imgs.shape)
print(targets)
writer.add_images("test_dataloader",imgs,i)
i+=1
writer.close()
结果:
【PS:因为drop_last = false,将保留不能除尽的部分在tensorboard的显示效果,同时这里使用的是add_images。若drop_last = true将删除最后一个不完整的批处理,可自行尝试】
本文作者:九重!
本文链接:https://blog.csdn.net/weixin_43798572/article/details/124237002
关于博主:评论和私信会在第一时间回复。或者直接私信我。
声援博主:如果您觉得文章对您有帮助,可以点击文章右下角【点赞】【收藏】一下。您的鼓励是博主的最大动力!