使用Pytorch_Geometric(PyG)时构建DataLoader,从DataLoader获取样本Batch时报错:RuntimeError: Sizes of tensors must match except in dimension 0.
报错原因是数据对齐错误,1个batch是多个样本的集合,在样本拼接成集合时出现错误,其规律如下:
- 使用pytorch-geometric的dataloader时,batch的各个样本合并规则
- 属性edge_index规则特殊,每个样本edge_index为 2 × e i 2\times e_i 2×ei,则合并n个样本形成一个batch之后的batch.edge_index大小为 2 × ( ∑ i = 1 n e i ) 2\times(\sum_{i=1}^n e_i) 2×(∑i=1nei)
- 其他所有属性如果为tensor,则按照第一个维度扩展,例如对于属性
x
x
x,第一个样本大小为
d
1
×
d
2
d_1\times d_2
d1×d2,第二个样本大小为
d
3
×
d
2
d_3\times d_2
d3×d2,则如果有一个batch包含这两个样本,batch.x的大小会是
(
d
3
+
d
1
)
×
d
2
(d_3+d_1)\times d_2
(d3+d1)×d2。这里一个巨坑,要求除了第一个维度之外,其他维度大小都必须要相同!! 否则会报错
RuntimeError: Sizes of tensors must match except in dimension 0.
- 其他属性如果不是tensor,就会正常按照列表返回,batch.x=[ 样本1的x,样本2的x,样本3的x]
如何解决:
- 如果是使用torch tensor引起的,可以考虑想办法对齐除了第一个维度外,其他维度的宽度。
- 如果没办法对齐,使用非tensor数据类型替换,例如列表。
- 最后的选择,指定batch_size=1以规避。
dataloader=DataLoader(MyData,batch_size=1)
2022/06/23原始
2023/02/20更新
https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html
这个是官网更详细的描述,直接看这个简单