【pytorch】数据读取

构建batch的函数接口是torch.utils.data.DataLoader,通过这个接口,可以实现迭代式地给出一个一个batch数据,送入模型。

DataLoader需要提供四个比较关键的参数:dataset,batch_size,drop_last,collate_fn。其中dataset可以是torch.utils.data.Dataset的子类,也可以是torch.utils.data.TensorDataset,等等。最常见的是torch.utils.data.Dataset的子类,可以自由构造各种格式的样本,需要重写父类中的__init__(),还有__getitem__()。batch_size,drop_last不用解释,collate_fn就比较有意思,是对dataset每一次__getitem__()返回的结果再做格式整合。

具体看个例子:

import torch
import torch.utils.data as Data
import numpy as np

test = np.array([i for i in range(12)])
inputing = torch.tensor([test[i:i+3] for i in range(10)])
target = torch.tensor([test[i:i+1] for i in range(10)])
torch_dataset = Data.TensorDataset(inputing, target)

for x in torch_dataset:
    print(x)

"""
TensorDataset的作用是把inputing和target组合在一起,像zip一样

result:
(tensor([0, 1, 2], dtype=torch.int32), tensor([0], dtype=torch.int32))
(tensor([1, 2, 3], dtype=torch.int32), tensor([1], dtype=torch.int32))
(tensor([2, 3, 4], dtype=torch.int32), tensor([2], dtype=torch.int32))
(tensor([3, 4, 5], dtype=torch.int32), tensor([3], dtype=torch.int32))
(tensor([4, 5, 6], dtype=torch.int32), tensor([4], dtype=torch.int32))
(tensor([5, 6, 7], dtype=torch.int32), tensor([5], dtype=torch.int32))
(tensor([6, 7, 8], dtype=torch.int32), tensor([6], dtype=torch.int32))
(tensor([7, 8, 9], dtype=torch.int32), tensor([7], dtype=torch.int32))
(tensor([ 8,  9, 10], dtype=torch.int32), tensor([8], dtype=torch.int32))
(tensor([ 9, 10, 11], dtype=torch.int32), tensor([9], dtype=torch.int32))

"""


loader = Data.DataLoader(dataset=torch_dataset, batch_size=3, drop_last=True,
collate_fn=lambda x:x)

for i in loader:
    print(i)
    print("-------")

"""
此时collate_fn不起任何修饰作用,loader返回batch,是一个list,list大小为batch_size,每一个list中的元素为tuple,tuple中的两个元素都为tensor,对应开头的inputing和target。

result:
[(tensor([0, 1, 2], dtype=torch.int32), tensor([0], dtype=torch.int32)), 
(tensor([1, 2, 3], dtype=torch.int32), tensor([1], dtype=torch.int32)), 
(tensor([2, 3, 4], dtype=torch.int32), tensor([2], dtype=torch.int32))]
-------
[(tensor([3, 4, 5], dtype=torch.int32), tensor([3], dtype=torch.int32)), 
(tensor([4, 5, 6], dtype=torch.int32), tensor([4], dtype=torch.int32)),
 (tensor([5, 6, 7], dtype=torch.int32), tensor([5], dtype=torch.int32))]
-------
[(tensor([6, 7, 8], dtype=torch.int32), tensor([6], dtype=torch.int32)), 
(tensor([7, 8, 9], dtype=torch.int32), tensor([7], dtype=torch.int32)), 
(tensor([ 8,  9, 10], dtype=torch.int32), tensor([8], dtype=torch.int32))]
-------

"""


collate_fn = lambda x: [
    torch.cat([x[i][j].unsqueeze(0) for i in range(len(x))], 0).unsqueeze(0)
    for j in range(len(x[0]))
]

loader_2 = Data.DataLoader(dataset=torch_dataset, batch_size=3, drop_last=True,
collate_fn=collate_fn)

for i in loader_2 :
    print(i)
    print("-------")

"""
collate_fn相当于让batch中的inputing和target各自聚合起来。这种两重for的写法非常pythonic,
x[i][j]...for i in range(...) 这里只说明了i,相当于引入一个未知量j。留到后面说明就行,比如这里是x[i][j]...for i in range(...)处理完,再for j in range(...)

result:
[tensor([[[0, 1, 2],
         [1, 2, 3],
         [2, 3, 4]]], dtype=torch.int32), 
tensor([[[0],
         [1],
         [2]]], dtype=torch.int32)]
-------
[tensor([[[3, 4, 5],
         [4, 5, 6],
         [5, 6, 7]]], dtype=torch.int32), 
tensor([[[3],
         [4],
         [5]]], dtype=torch.int32)]
-------
[tensor([[[ 6,  7,  8],
         [ 7,  8,  9],
         [ 8,  9, 10]]], dtype=torch.int32), 
tensor([[[6],
         [7],
         [8]]], dtype=torch.int32)]
-------

"""

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值