test 1 关于NeighborSampler
1. 备注:pyg中有类torch_geometric.data.Dataloader(pyg1.7.2以前版本,pyg2.0.0以后版本是torch_geometric.loader.DataLoader),同时torch_geometric.data.Dataloader是torch.DataLoader的子类。
2. neighborsampler是torch.utils.data.DataLoader的子类。
# 关于NeighborSampler的代码示例
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.data import NeighborSampler
row_data = np.load("../data/Reddit_coo/edge/col.npy", allow_pickle = True)
col_data = np.load("../data/Reddit_coo/edge/row.npy", allow_pickle = True)
indices = np.vstack((row_data, col_data))
# device = torch.device("cpu")
adj = torch.LongTensor(indices)
#adj = adj.cuda()
train_loader = NeighborSampler(adj,
#node_idx=train_idx,
sizes=[5], batch_size=1024,
shuffle=True, num_workers=1)
2. 关于NeighborSampler的源码:
pyg官网中1.7.2版本关于NeighborSampler源码
以上面的代码为例,关于这段源码的一些备注。
1.
adj.max()
out[1] : tensor(232964)
adj.size(0)
adj.size(1)
out[2]:2
out[3]:114615892
value = torch.arange(adj.size(1))
out[1]:tensor([ 0, 1, 2, ..., 114615889, 114615890, 114615891])
from torch_sparse import SparseTensor
self.adj_t = SparseTensor(row=adj[0], col=adj[1],
value=value,
sparse_sizes=(num_nodes, num_nodes)).t()
①其中,SparseTensor().t()表示转置。
SparseTensor为torch_sparse包里在.tensor文件定义的类。
②pytorch中.t()与.T都表示转置。查阅官方文档,发现 .t() 是 .transpose函数的简写版本,只能对2维以下的tensor进行转置。.transpose函数对一个n维tensor交换其任意两个维度的顺序。
而 .T 是 .permute 函数的简化版本,不仅可以操作2维tensor,甚至可以对n维tensor进行转置。当然当维数n=2时,.t() 与 .T 效果是一样的。
③关于Pytorch中contiguous()函数
在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新定义下标与元素的对应关系的。换句话说,这种操作不进行数据拷贝和数据的改变,变的是元数据。
会改变元数据的操作是:
- narrow()
- view()
- expand()
- transpose()
在使用transpose()
进行转置操作时,pytorch并不会创建新的、转置后的tensor,而是修改了tensor中的一些属性(也就是元数据),使得此时的offset和stride是与转置tensor相对应的。转置的tensor和原tensor的内存是共享的。
而如果想要断开这两个变量之间的依赖(x本身是contiguous的),就要使用contiguous()针对x进行变化,感觉上就是我们认为的深拷贝。
当调用contiguous()时,会强制拷贝一份tensor,让它的布局和从头创建的一模一样,但是两个tensor完全没有联系。
本段引用自: Pytorch中contiguous()函数理解_.contiguous()_清晨的光明的博客-CSDN博客
# torch_sparse/storage.py中
rowptr = torch.ops.torch_sparse.ind2ptr(row, self._sparse_sizes[0])
# 节选自:
def rowptr(self) -> torch.Tensor:
rowptr = self._rowptr
if rowptr is not None:
return rowptr
row = self._row
if row is not None:
rowptr = torch.ops.torch_sparse.ind2ptr(row, self._sparse_sizes[0])
self._rowptr = rowptr
return rowptr
raise ValueError
此段存疑。不过猜测是用来获取row[ ]中各个点的指针的。
node_idx.view(-1).tolist()
关于view的解释,可参考: tensor.view(*shape) 函数_tensor view函数_三世的博客-CSDN博客
view( -1 )即转换维度为一维。实际使用中常看见data.contiguous().view(-1)的用法,contiguous()
就是为了保证一个Tensor是连续的,这样才能被 view() 处理。
2. 对于NeighborSampler类的__init__,最后
super(NeighborSampler, self).__init__(
node_idx.view(-1).tolist(), collate_fn=self.sample, **kwargs)
其父类为torch.utils.data.DataLoader。
3. 理解Pytorch中torch.utils.data
以PyTorch 1.7.2版本为准
迭代器
在 Dataset
, Sampler
和 DataLoader
这三个类中都会用到 python 抽象类的方法,包括__len__(self)
,__getitem__(self)
和 __iter__(self)
__len__(self)
: 定义当被len()
函数调用时的行为,一般返回迭代器中元素的个数__getitem__(self)
: 定义获取容器中指定元素时的行为,相当于self[key]
,即允许类对象拥有索引操作__iter__(self)
: 定义当迭代容器中的元素时的行为
每一次重复的过程被称为一次迭代的过程,而每一次迭代得到的结果会被用来作为下一次迭代的初始值。提供迭代方法的容器称为迭代器,通常接触的迭代器有序列(列表、元组和字符串)还有字典,这些数据结构都支持迭代操作。
实现迭代器的魔法方法有两个:__iter__(self)
和 __next__(self)
一个容器如果是迭代器,那就必须实现 __iter__(self)
方法,这个方法实际上是返回是一个迭代器(通常是迭代器本身)。接下来重点要实现的是 __next__(self)
方法,因为它决定了迭代的规则。
一般来说,迭代器满足以下几种特性:
- 迭代器是⼀个对象
- 迭代器可以被 next() 函数调⽤,并返回⼀个值
- 迭代器可以被 iter() 函数调⽤,并返回一个迭代器(可以是自身)
- 连续被 next() 调⽤时依次返回⼀系列的值
- 如果到了迭代的末尾,则抛出 StopIteration 异常
- 迭代器也可以没有末尾,只要被 next() 调⽤,就⼀定会返回⼀个值
- Python 中, next() 内置函数调⽤的是对象的 next() ⽅法
- Python 中, iter() 内置函数调⽤的是对象的 iter() ⽅法
- ⼀个实现了迭代器协议的的对象可以被 for 语句循环迭代直到终⽌
3.1 DataSet
DataLoader constructor最重要的参数,指明了一个dataset加载data的来源。
Dataset 负责对 raw data source 封装,将其封装成 Python 可识别的数据结构,其必须提供提取数据个体的接口。
Dataset 共有 Map-style datasets 和 Iterable-style datasets 两种。
Map-style datasets
表示从(可能是非整数)索引/关键字到数据样本的映射,
attribute | meaning | default value | type |
dataset | 加载数据的数据集 | Dataset | |
shuffle | 设置为 True 时,调用 RandomSampler 进行随机索引 | False | bool |
sampler | 定义从数据集中提取样本的策略 如果指定了, shuffle 参数必须为 False,(否则会和 RandomSampler 互斥) | None | Sampler, Iterable |
sampler | 定义从数据集中提取样本的策略 如果指定了, shuffle 参数必须为 False,(否则会和 RandomSampler 互斥) | None | Sampler, Iterable |
num_workers | 要用于数据加载的子进程数,0 表示将在主进程中加载数据 | 0 | int |
collate_fn | 在将 Map-style datase t 取出的数据整合成 batch 时使用,合并样本列表以形成一个 batch | None | callable |
pin_memory | 如果为 True,则 DataLoader 在将张量返回之前将其复制到 CUDA 固定的内存中 | False | bool |
drop_last | 设置为 True 删除最后一个不完整的批次,如果该数据集大小不能被该批次大小整除。 如果 False 并且数据集的大小不能被批次大小整除,那么最后一批将较小 | False | bool |
timeout | 如果为正,则为从 worker 收集 batch 的超时值,应始终为非负数 超过这个时间还没读取到数据的话就会报错 | 0 | numeric |
worker_init_fn | 如果不为 None,它将会被每个 worker 子进程调用, 以 worker id ([0, num_workers - 1] 内的整形) 为输入 | None | callable |
prefetch_factor | 每个 worker 提前加载 的 sample 数量 | 2 | int |
persistent_workers | 如果为 True,dataloader 将不会终止 worker 进程,直到 dataset 迭代完成 | False | bool |
batch_sampler | 和 sampler 类似,但是一般传入 BatchSampler,每次返回一个 batch 大小的索引 其和 batch_size, shuffle 等参数是互斥的 | None | Sampler, Iterable |
3.2-3.3 Sampler 和 Dataloader写了半天没保存,心态炸了,等调整好回来编辑。。。
4.1 三者关系
本段引用自:PyTorch 源码解读之 torch.utils.data:解析数据处理全流程 - 掘金 (juejin.cn)
- 设置 Dataset,将数据 data source 包装成 Dataset 类,暴露提取接口。
- 设置 Sampler,决定采样方式。我们是能从 Dataset 中提取元素了,还是需要设置 Sampler 告诉程序提取 Dataset 的策略。
- 将设置好的 Dataset 和 Sampler 传入 DataLoader,同时可以设置
shuffle
,batch_size
等参数。使用 DataLoader 对象可以方便快捷地在数据集上遍历。
总结来说,即 Dataloader 负责总的调度,命令 Sampler 定义遍历索引的方式,然后用索引去 Dataset 中提取元素。于是就实现了对给定数据集的遍历。
4.2 批处理
4.2.1 自动批处理(默认)
本段引用自:PyTorch 源码解读之 torch.utils.data:解析数据处理全流程 - 掘金 (juejin.cn)
DataLoader
支持通过参数batch_size
, drop_last
, batch_sampler
,自动地把取出的数据整理 (collate) 成批次样本 (batch)
batch_size
和 drop_last
参数用于指定 DataLoader 如何获取 dataset 的 key。特别地,对于 map-style 类型的 dataset,用户可以选择指定 batch_sample
参数,一次就生成一个 keys list
在使用 sampler 产生的 indices 获取采样到的数据时,DataLoader 使用 collate_fn
参数将样本列表整理成 batch。
# For Map-style
for indices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
# For Iterable-style
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
4.2.2 关闭自动批处理
当用户想用 dataset 代码手动处理 batch,或仅加载单个 sample data 时,可将 batch_size
和 batch_sampler
设为 None
, 将关闭自动批处理。此时,由 Dataset
产生的 sample 将会直接被 collate_fn
处理。
# For Map-style
for index in sampler:
yield collate_fn(dataset[index])
# For Iterable-style
for data in iter(dataset):
yield collate_fn(data)
4.2.3 collate_fn
当关闭自动批处理 (automatic batching) 时,collate_fn
作用于单个数据样本,只是在 PyTorch 张量中转换 NumPy 数组。
当开启自动批处理 (automatic batching) 时,collate_fn
作用于数据样本列表,将输入样本整理为一个 batch,一般做下面 3 件事情
- 添加新的批次维度(一般是第一维)
- 它会自动将 NumPy 数组和 Python 数值转换为 PyTorch 张量
- 它保留数据结构,例如,如果每个样本都是
dict
,则输出具有相同键集但批处理过的张量作为值的字典(或list,当不能转换的时候)。list
,tuples
,namedtuples
同样适用
自定义 collate_fn
可用于自定义排序规则,例如,将顺序数据填充到批处理的最大长度,添加对自定义数据类型的支持等。