原文链接:https://blog.csdn.net/weixin_42028364/article/details/81675021
import torch
import torch.utils.data as Data
import numpy as np
# Toy 1-D series 0..11 from which the sliding-window samples are cut.
test = np.arange(12)
# Ten overlapping windows of length 3: [0, 1, 2], [1, 2, 3], ..., [9, 10, 11].
inputing = torch.tensor(np.stack([test[start:start + 3] for start in range(10)]))
# Each target is the first element of its window, kept 2-D with shape (10, 1).
target = torch.tensor(test[:10].reshape(10, 1))
print(inputing.shape, target.shape, sep="\n")
# Pair every input window with its target so both can be indexed together.
torch_dataset = Data.TensorDataset(inputing, target)
# Number of samples per batch requested from the DataLoader.
batch = 3
def co(x):
    """Collate a list of samples into one batched tensor per field.

    Intended for use as a ``collate_fn``: ``x`` is a list of samples, each
    an indexable collection of tensors (for example the ``(input, target)``
    pairs produced by a ``TensorDataset``).

    For every field position ``j`` the ``j``-th tensors of all samples are
    stacked along a new dimension 0, and a further leading dimension of
    size 1 is then added, giving a shape of ``(1, len(x), *field_shape)``.

    Args:
        x: Non-empty list of samples. The number of fields is taken from
            the first sample.

    Returns:
        A generator yielding one tensor per field, in field order.

    Raises:
        IndexError: If ``x`` is empty (raised at call time, not lazily).
    """
    # Read the field count eagerly so that an empty batch fails right here
    # rather than when the generator is first consumed.
    num_fields = len(x[0])
    # NOTE(review): generator objects cannot be pickled, so this collate_fn
    # is presumably only usable with num_workers=0 -- confirm before
    # enabling worker processes.
    return (
        # torch.stack is equivalent to unsqueeze(0) on each tensor + cat.
        torch.stack([sample[field] for sample in x], 0).unsqueeze(0)
        for field in range(num_fields)
    )
# Serve the dataset in batches of `batch` samples, delegating batch assembly
# to the custom collate function `co` instead of the default collation.
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=batch,
    collate_fn=co,
)
# `co` yields one tensor per dataset field, so every batch unpacks into the
# stacked inputs (i) and the stacked targets (j).
for (i, j) in loader:
    print(i, j)
原文中使用的是 lambda 函数,这里改写为 def 形式以方便理解。
该函数(co)主要实现的功能为:将一个批次内各样本的同一字段在第 0 维上拼接成一个张量,再在最前面增加一个维度,并以生成器的形式依次返回每个字段对应的张量。