首先再开始之前,我想问一下在座的各位知道DataLoader的原理吗,不会的扣1,会的小伙伴们扣脚指头,嘿嘿,开玩笑的,我也不知道,那么接下来我们一起学习如何写吧!
推荐教学视频:4、PyTorch的Dataset与DataLoader详细使用教程_哔哩哔哩_bilibili
5、深入剖析PyTorch DataLoader源码_哔哩哔哩_bilibili(强烈建议,因为孩子也不知道博主讲到哪了,太快了(哭))
推荐学习链接:Pytorch的Dataset和DataLoader详细使用教程_后来后来啊的博客-CSDN博客
pytorch官方示例代码:Datasets & DataLoaders — PyTorch Tutorials 2.0.1+cu117 documentation
官方给出了代码
Iterate through the DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True) #创建了一个`train_dataloader`对象,用于加载训练数据。`training_data`是训练数据集,`batch_size=64`表示每个批次的样本数量为64,`shuffle=True`表示在每个epoch开始时对数据进行随机洗牌
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)#创建了一个`test_dataloader`对象,用于加载测试数据。`test_data`是测试数据集,`batch_size=64`和`shuffle=True`的含义与上述相同。
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))#使用`next(iter(train_dataloader))`从`train_dataloader`中获取一个批次的训练特征和标签。`train_features`是训练特征,`train_labels`是对应的训练标签。
print(f"Feature batch shape: {train_features.size()}")#打印出训练特征的形状,`train_features.size()`返回一个元组,表示特征的维度。例如,如果特征是一个3维张量,形状为(64, 1, 28, 28),则打印结果为"Feature batch shape: torch.Size([64, 1, 28, 28])"
print(f"Labels batch shape: {train_labels.size()}")#打印出训练标签的形状,`train_labels.size()`返回一个元组,表示标签的维度。例如,如果标签是一个1维张量,形状为(64,),则打印结果为"Labels batch shape: torch.Size([64])
img = train_features[0].squeeze()#将训练特征中的第一张图片展示出来。`train_features[0]`表示取第一张图片,
label = train_labels[0]# `squeeze()`函数用于去除维度为1的维度,将形状为(1, 28, 28)的张量转换为形状为(28, 28)的张量。`plt.imshow()`函数用于显示图片,`cmap="gray"`表示以灰度图的形式显示
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")#打印出训练标签的值,即该图片对应的标签。例如,如果标签是一个整数,值为5,则打印结果为"Label: 5"
在pycharm中若想查看源码,则只需要输入import torch.utils.data.dataloader,然后按住ctrl+dataloader即可以查看源码
Dataloader相对于dataset,相当于把数据组成一批次,然后在进行处理
以下是部分源码
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
shuffle: bool = False, sampler: Optional[Sampler] = None,
batch_sampler: Optional[Sampler[Sequence]] = None,
num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
multiprocessing_context=None, generator=None,
*, prefetch_factor: int = 2,
persistent_workers: bool = False):
torch._C._log_api_usage_once("python.data_loader")
#shuffle,在训练结束后随机打乱
#sample:可以自定义在里面取数据,例如想让样本已一种有序的方式取样
#设置了sample和batch_sample就不需要设置shuffle了
#num_workers:多进行读取,一般以0读取
#collate_fn:聚集函数
sampler作用:以某种顺序去dataset里面取元素,样本
iter实际走的就是上述函数,返回self._ger_iteration(),即下图
上函数继承于
next官方讲解:
def __next__(self) -> Any:
with torch.autograd.profiler.record_function(self._profile_name):
if self._sampler_iter is None:
self._reset()
data = self._next_data()
self._num_yielded += 1
if self._dataset_kind == _DatasetKind.Iterable and \
self._IterableDataset_len_called is not None and \
self._num_yielded > self._IterableDataset_len_called:
warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
"samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
self._num_yielded)
if self._num_workers > 0:
warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
"IterableDataset replica at each worker. Please see "
"https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
warnings.warn(warn_msg)
return data
对next的调用就返回了miniset的调用
总结:在pytorch中需要自定义,按照官网教程自定义即可,加载文件,预处理等;在dataset中至少需要实现__init__,__len__,__getitem__三个方法(详细用法可以 查看Pytorch的Dataset和DataLoader详细使用教程_后来后来啊的博客-CSDN博客)
做完后需要放入dataloader中,设计sampler:决定在datase中取样本的 顺序,若不设置设置suffer=True,则会内置实例化一个RandomSampler(把数据库索引中打乱,随机选取),collate_fn:对已经取得进行填充,厚处理(相当于什么都没干,有些文本分类需要使用),该函数以batch作为输入,以batch作为输出;iter方法源码走346行逻辑