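The interface for building batches is torch.utils.data.DataLoader. Iterating over a DataLoader yields one batch at a time, ready to be fed into a model.
Four DataLoader parameters matter most: dataset, batch_size, drop_last, and collate_fn. dataset is normally a subclass of torch.utils.data.Dataset; the built-in torch.utils.data.TensorDataset is itself one such subclass. Writing your own subclass is the most common route, because it lets you build samples in any format: override __init__() and __getitem__(), plus __len__() so that the default sampler knows how many samples exist. batch_size and drop_last need little explanation (drop_last=True discards a final batch that is smaller than batch_size). collate_fn is the more interesting one: it receives the list of samples that __getitem__() returned for one batch and merges them into the final batch format.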
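As a minimal sketch of the Dataset-subclass route, the hypothetical WindowDataset class below (the class name and its sliding-window layout are assumptions made for illustration, not part of PyTorch) implements __len__() and __getitem__(), and then shows how drop_last changes the number of batches.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class WindowDataset(Dataset):
    """Hypothetical map-style dataset: sample i is (values[i:i+3], values[i:i+1])."""

    def __init__(self, values, window=3):
        self.values = torch.as_tensor(values)
        self.window = window

    def __len__(self):
        # Number of full windows that fit in the sequence.
        return len(self.values) - self.window + 1

    def __getitem__(self, idx):
        # Return one (input, target) pair as a tuple of tensors.
        return self.values[idx:idx + self.window], self.values[idx:idx + 1]


dataset = WindowDataset(list(range(12)))  # 12 values -> 10 windows
keep_tail = DataLoader(dataset, batch_size=3, drop_last=False)
drop_tail = DataLoader(dataset, batch_size=3, drop_last=True)

# 10 samples with batch_size=3: 4 batches if the short tail is kept, 3 if dropped.
print(len(dataset), len(keep_tail), len(drop_tail))  # 10 4 3
```

With drop_last=False the fourth batch would contain only one sample, which is why some training loops prefer to drop it.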
Here is a concrete example:
import torch
import torch.utils.data as Data
import numpy as np

test = np.arange(12)
# Sliding windows: inputing[i] = [i, i+1, i+2], target[i] = [i].
# Converting through np.array avoids building a tensor from a list of arrays.
inputing = torch.tensor(np.array([test[i:i+3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i+1] for i in range(10)]))
torch_dataset = Data.TensorDataset(inputing, target)
for x in torch_dataset:
    print(x)
"""
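TensorDataset pairs inputing and target element by element, much like zip.
The dtype=torch.int32 below comes from NumPy's default integer type on the machine used here; on other platforms it may be int64.
result: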
(tensor([0, 1, 2], dtype=torch.int32), tensor([0], dtype=torch.int32))
(tensor([1, 2, 3], dtype=torch.int32), tensor([1], dtype=torch.int32))
(tensor([2, 3, 4], dtype=torch.int32), tensor([2], dtype=torch.int32))
(tensor([3, 4, 5], dtype=torch.int32), tensor([3], dtype=torch.int32))
(tensor([4, 5, 6], dtype=torch.int32), tensor([4], dtype=torch.int32))
(tensor([5, 6, 7], dtype=torch.int32), tensor([5], dtype=torch.int32))
(tensor([6, 7, 8], dtype=torch.int32), tensor([6], dtype=torch.int32))
(tensor([7, 8, 9], dtype=torch.int32), tensor([7], dtype=torch.int32))
(tensor([ 8, 9, 10], dtype=torch.int32), tensor([8], dtype=torch.int32))
(tensor([ 9, 10, 11], dtype=torch.int32), tensor([9], dtype=torch.int32))
"""
# An identity collate_fn hands each batch back exactly as the DataLoader assembled it.
loader = Data.DataLoader(dataset=torch_dataset, batch_size=3, drop_last=True,
                         collate_fn=lambda x: x)
for i in loader:
    print(i)
    print("-------")
"""
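Here collate_fn changes nothing, so the loader returns each batch in raw form: a list of length batch_size whose elements are tuples, and each tuple holds two tensors corresponding to inputing and target above. Because drop_last=True and 10 samples do not divide evenly by 3, the last sample (index 9) is dropped.
result: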
[(tensor([0, 1, 2], dtype=torch.int32), tensor([0], dtype=torch.int32)),
(tensor([1, 2, 3], dtype=torch.int32), tensor([1], dtype=torch.int32)),
(tensor([2, 3, 4], dtype=torch.int32), tensor([2], dtype=torch.int32))]
-------
[(tensor([3, 4, 5], dtype=torch.int32), tensor([3], dtype=torch.int32)),
(tensor([4, 5, 6], dtype=torch.int32), tensor([4], dtype=torch.int32)),
(tensor([5, 6, 7], dtype=torch.int32), tensor([5], dtype=torch.int32))]
-------
[(tensor([6, 7, 8], dtype=torch.int32), tensor([6], dtype=torch.int32)),
(tensor([7, 8, 9], dtype=torch.int32), tensor([7], dtype=torch.int32)),
(tensor([ 8, 9, 10], dtype=torch.int32), tensor([8], dtype=torch.int32))]
-------
"""
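def collate_fn(x):
    # x is a list of (inputing, target) tuples, one per sample in the batch.
    # For each field j, concatenate that field across all samples i,
    # then add a leading dimension of size 1.
    return [
        torch.cat([x[i][j].unsqueeze(0) for i in range(len(x))], 0).unsqueeze(0)
        for j in range(len(x[0]))
    ]

loader_2 = Data.DataLoader(dataset=torch_dataset, batch_size=3, drop_last=True,
                           collate_fn=collate_fn)
for i in loader_2:
    print(i)
    print("-------")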
"""
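This collate_fn gathers all inputing tensors of a batch into one tensor and all target tensors into another. The nested comprehension is quite Pythonic and reads from the outside in:
the outer "for j in range(len(x[0]))" picks a field (0 for inputing, 1 for target), and for that j the inner "for i in range(len(x))" walks over the samples. Inside the inner comprehension j looks like an unknown, since only i is declared there; it is bound by the outer loop that is written after it. Each resulting tensor has shape (1, batch_size, ...) because of the final unsqueeze(0).
result: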
[tensor([[[0, 1, 2],
[1, 2, 3],
[2, 3, 4]]], dtype=torch.int32),
tensor([[[0],
[1],
[2]]], dtype=torch.int32)]
-------
[tensor([[[3, 4, 5],
[4, 5, 6],
[5, 6, 7]]], dtype=torch.int32),
tensor([[[3],
[4],
[5]]], dtype=torch.int32)]
-------
[tensor([[[ 6, 7, 8],
[ 7, 8, 9],
[ 8, 9, 10]]], dtype=torch.int32),
tensor([[[6],
[7],
[8]]], dtype=torch.int32)]
-------
"""