Pytorch中的collate_fn、dataloader和dataset

collate_fn的输入是由batch_size个sample组成的list

DataLoader是拿出batch_size条features和label交给这个函数来collate,所以简单来说,x的最外层是batch个数,x的次外层是特征和label,因此len(x[0])其实是2,以下是一个测试code

import torch.utils.data as Data
import numpy as np
import torch

# Raw data: 12 consecutive integers used to build sliding windows.
test = np.arange(12)
# Features: 10 overlapping windows of width 3 -> shape [10, 3].
inputing = torch.tensor(np.stack([test[k:k + 3] for k in range(10)]))
# Labels: each window's start value as a length-1 slice -> shape [10, 1].
target = torch.tensor(np.stack([test[k:k + 1] for k in range(10)]))
torch_dataset = Data.TensorDataset(inputing, target)
batch = 3


def collate_fn(x):
    """Collate a batch of (features, label) samples.

    Input: x is a list of length batch_size; each element is a
    (features, label) pair, so len(x[0]) == 2.
    Output: a generator yielding one stacked tensor per field —
    features of shape [batch_size, embedding_size], then labels of
    shape [batch_size, 1].  unsqueeze(0) prepares each sample for
    torch.cat; torch.stack would also work (stack adds a dim, cat doesn't).
    """
    n_samples = len(x)
    n_fields = len(x[0])
    # Debug prints: show the per-sample pieces and the stacked result.
    for field in range(n_fields):
        expanded = [x[row][field].unsqueeze(0) for row in range(n_samples)]
        print('tmp1={}'.format(expanded))
        print('tmp2={}'.format(torch.cat(expanded)))
    # Note: this deliberately returns a *generator*, not a tuple — the
    # article demonstrates the consequences of that below.
    return (
        torch.cat([x[row][field].unsqueeze(0) for row in range(n_samples)])
        for field in range(n_fields)
    )


# Wire the dataset to a DataLoader that batches through the custom collate_fn.
loader = Data.DataLoader(dataset=torch_dataset, batch_size=batch, collate_fn=collate_fn)
# collate_fn returns a 2-element generator, so it unpacks into features/label.
for features, label in loader:
    print('features={} label={}'.format(features, label))

也就是说,就每个batch产出的张量内容而言,默认的collate_fn和下面的写法是等价的(注意:默认实现返回的是list而非generator):

# Default collation: no collate_fn supplied.
loader = Data.DataLoader(dataset=torch_dataset, batch_size=batch)
# Equivalent explicit collation: for each field j, stack the j-th tensor of
# every sample in the batch along a new leading dimension.
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=batch,
    collate_fn=lambda samples: (
        torch.cat([row[j].unsqueeze(0) for row in samples])
        for j in range(len(samples[0]))
    ),
)

collate_fn的输出是dataloader每次迭代产出的一个batch(而不是单个sample)

实际中下面代码有时候不光是features, label这两部分,有时候复杂场景loader中的一个sample会返回一个dict,因此下面的code迭代会出错:

# Same wiring as above: custom collate_fn drives the batching.
loader = Data.DataLoader(
    dataset=torch_dataset, batch_size=batch, collate_fn=collate_fn
)
# Unpacking works here only because each batch is exactly two fields;
# with a dict-shaped batch this two-variable form would fail.
for features, label in loader:
    print('features={} label={}'.format(features, label))

需要改成下面next(iter(loader))的形式或者for里面只用一个item:

import torch.utils.data as Data
import numpy as np
import torch

# 12 consecutive integers serve as the raw sequence.
test = np.array(list(range(12)))
# inputing: 10 sliding windows of width 3 -> shape [10, 3]
inputing = torch.tensor(np.array([test[k:k + 3] for k in range(10)]))
# target: the window's first element as a length-1 slice -> shape [10, 1]
target = torch.tensor(np.array([test[k:k + 1] for k in range(10)]))
torch_dataset = Data.TensorDataset(inputing, target)
batch = 3


def collate_fn(x):
    """Collate batch_size (features, label) pairs into per-field tensors.

    x is a list of samples; len(x[0]) == 2 (features and label).
    Returns a generator of torch.cat-stacked tensors: features of shape
    [batch_size, embedding_size], then labels of shape [batch_size, 1].
    unsqueeze(0) adds the leading dim torch.cat concatenates along;
    torch.stack would achieve the same without the manual unsqueeze.
    """
    # Debug output: per-sample unsqueezed pieces, then the stacked tensor.
    for j in range(len(x[0])):
        parts = [sample[j].unsqueeze(0) for sample in x]
        print('tmp1={}'.format(parts))
        print('tmp2={}'.format(torch.cat(parts)))
    # Returned value is a generator — iterating the loader without
    # unpacking yields generator objects (demonstrated below).
    return (
        torch.cat([sample[j].unsqueeze(0) for sample in x])
        for j in range(len(x[0]))
    )


loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=batch,
    collate_fn=collate_fn
)

# You don't have to iterate the loader with a for-loop; next(iter(loader))
# pulls a single element from the iterator.  collate_fn returns a 2-element
# generator here, so it unpacks cleanly into (features, value).
features,value = next(iter(loader))
print('next(iter(loader))={},{}'.format(features,value))
'''
这个print的输出为:
next(iter(loader))=tensor([[0, 1, 2],
        [1, 2, 3],
        [2, 3, 4]]),tensor([[0],
        [1],
        [2]])
'''

# Without unpacking, the batch is the raw generator object returned by
# collate_fn.  NOTE(review): the sample output below is labelled "item=",
# but this print actually prefixes "next(iter(loader))=".
item = next(iter(loader))
print('next(iter(loader))={}'.format(item))
'''
这个print的输出为:
item=<generator object collate_fn.<locals>.<genexpr> at 0x7fa5d85a0ad0>
'''

# Iterating with a single loop variable likewise yields the generator itself.
for item in loader:
    print('item={}'.format(item))
    # prints: item=<generator object collate_fn.<locals>.<genexpr> at 0x...>

print(len(loader))  # prints 4: ceil(10 samples / batch_size 3) batches

# next(iter(...)) works on any iterable, e.g. a plain list of tuples.
a = [(1,1,),(2,2,),(3,3)]
print('next(iter(a))={}'.format(next(iter(a))))

dataloader本质上是一个能够迭代的对象,迭代元素由collate函数返回

Basically iter() calls the __iter__() method on the dataloader which returns an iterator. next() then calls the __next__() method on that iterator to get the first iteration. Running next() again will get the second item of the iterator, etc.

当训练数据很多时,使用迭代器可以有效避免一次性全部load到内存里占用内存。一般来说,每次next会触发一次collate函数,但这一点不是绝对的:有时候为了速度,DataLoader会预先多加载几个batch,也就是预先多调用几次collate函数。因此,当数据很多时,用python3 -m pdb进行debug、且断点打在collate函数里时,pdb可能会崩

最后结合dataloader和dataset,写一个实用的拿batch代码

注意几点:

  • tokenizer经常被放到dataset里面完成
  • 下面代码提供一个拼dic的collate函数
  • dataloader可以使用prefetch factor
import os
import pandas as pd
import json
from torch.utils.data import Dataset
from PIL import Image
from torch.utils.data import DataLoader
from transformers import (AutoConfig, AutoTokenizer, InstructBlipProcessor)
import torch
from collections import defaultdict

class CustomImageDataset(Dataset):
    """Map-style dataset: one annotation entry -> one dict sample.

    Each entry of the annotation JSON carries the keys 'image',
    'gid_imageidx' (formatted "<gid>_<imageidx>"), 'image_url' and
    'label'.  The HF processor (tokenizer) is kept inside the dataset,
    so tokenization runs in __getitem__ (i.e. in DataLoader workers).
    """
    def __init__(self, annotations_file = 'path/to/ann.json'):
        # The entire annotation list is loaded into memory up front.
        self.ann_lst = json.load(open(annotations_file, "r"))
        model_type = 'Salesforce/instructblip-flan-t5-xxl'
        self.processor = InstructBlipProcessor.from_pretrained(model_type)

    def __len__(self):
        # One sample per annotation entry.
        return len(self.ann_lst)

    def __getitem__(self, idx):
        """Load the idx-th image, run the InstructBLIP processor on it with
        a fixed prompt list, and return everything as a dict (a dict-merging
        collate_fn turns these per-sample dicts into a batch dict)."""
        def load_single_image(local_image_path):
            # The processor expects a batch (list) of images, hence the list.
            image = [Image.open(local_image_path).convert('RGB')
            ]
            return image
        image_path = self.ann_lst[idx].get('image', '')
        gid_imageidx = self.ann_lst[idx].get('gid_imageidx', '')
        image_url = self.ann_lst[idx].get('image_url', '')
        label = self.ann_lst[idx].get('label','')
        # 'gid_imageidx' is "<gid>_<imageidx>"; imageidx is parsed here but
        # not included in the returned dict.
        cols = gid_imageidx.split('_')
        gid,imageidx = int(cols[0]),int(cols[1])
        image = load_single_image(image_path)
        image_path_name,image_url_name,group_id_name,imageidx_name,prompt_name,no_yes_logit_name,output_text_name,label_name = 'image_path','image_url','group_id','imageidx','prompt','no_yes_logit','output_text','label'
        image_tensor_name,qformer_input_ids_name,qformer_attention_mask_name = 'image_tensor','qformer_input_ids','qformer_attention_mask'

        # Two yes/no prompts are asked about every image.
        prompt_lst = [
                "Question: Is a cat in the image? Answer:",
                "Question: Is a dog in the image? Answer:",
                ]
        text_bs = len(prompt_lst)
        # NOTE(review): .to("cuda") runs inside __getitem__, i.e. inside
        # DataLoader worker processes when num_workers > 0; doing CUDA work
        # in workers is generally discouraged — confirm this is intended.
        inputs = self.processor(image, prompt_lst, return_tensors="pt", padding=True).to("cuda")
        # Repeat the single image once per prompt so the image batch lines
        # up with the text batch.
        image = inputs['pixel_values'].repeat_interleave(text_bs, dim=0)
        qformer_input_ids = inputs['qformer_input_ids']
        qformer_attention_mask = inputs['qformer_attention_mask']
        dic = {
            group_id_name:gid,
            image_path_name:image_path,
            image_url_name:image_url,
            label_name:label,
            image_tensor_name:image,
            qformer_input_ids_name:qformer_input_ids,
            qformer_attention_mask_name:qformer_attention_mask
        }
        return dic

def collate_fn(lst):
    """Merge a list of per-sample dicts into a single batch dict.

    Tensor-valued fields are concatenated along dim 0; every other field
    becomes a plain list of the per-sample values.
    """
    grouped = defaultdict(list)
    for sample in lst:
        for key, value in sample.items():
            grouped[key].append(value)
    # Per-key decision: tensors get stacked, everything else stays a list.
    return {
        key: torch.cat(values) if torch.is_tensor(values[0]) else values
        for key, values in grouped.items()
    }

# Build the dataset and a multi-worker loader.  prefetch_factor controls how
# many batches each worker loads ahead (only valid when num_workers > 0).
dataset = CustomImageDataset()
test_dataloader = DataLoader(dataset, batch_size=3, shuffle=False, collate_fn=collate_fn, prefetch_factor=6, num_workers=3)
for batch in test_dataloader:
    print(batch)
    # NOTE(review): with prefetching, collate runs ahead of this loop, so a
    # breakpoint here (or inside collate_fn) can behave oddly under pdb.
    import pdb; pdb.set_trace();

json的样例格式如下:


[
 {
  "image": "image/path/1.jpg",
  "content": "title1",
  "gid_imageidx": "123_0",
  "image_url": "url1",
  "label": "0"
 },
 {
  "image": "image/path/2.jpg",
  "content": "title1",
  "gid_imageidx": "123_1",
  "image_url": "url2",
  "label": "0"
 },
 {
  "image": "image/path/3.jpg",
  "content": "title1",
  "gid_imageidx": "123_2",
  "image_url": "url3",
  "label": "0"
 }
]

collate_fn如果抛出exception怎么办?

如果直接for batch in dataloader: exception循环就退了,可以写成下面这样:

import torch.utils.data as Data
import numpy as np
import torch

# Sequence 0..11; features are width-3 sliding windows, labels are the
# corresponding window start values.
test = np.arange(12)
inputing = torch.tensor(np.stack([test[i:i + 3] for i in range(10)]))  # [10, 3]
target = torch.tensor(np.stack([test[i:i + 1] for i in range(10)]))    # [10, 1]
torch_dataset = Data.TensorDataset(inputing, target)
batch = 3

def collate_fn(x):
    """Collate (features, label) pairs into per-field tensors.

    Deliberately raises ZeroDivisionError for every full batch of 3
    samples, to demonstrate exception handling around DataLoader
    iteration.  Otherwise returns a generator yielding one
    torch.cat-stacked tensor per field (features, then labels).
    """
    # Simulated failure: every full batch (batch_size == 3) blows up here.
    if len(x) == 3:
        print(0/0)
    n_fields = len(x[0])
    return (
        torch.cat([sample[j].unsqueeze(0) for sample in x])
        for j in range(n_fields)
    )

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=batch,
    collate_fn=collate_fn
)

# Manual iteration so an exception inside collate_fn doesn't end the loop.
loader_iter = iter(loader)
# Pre-initialize: if the very first next() raises inside collate_fn (and the
# first full batch of 3 DOES raise in this demo), the except branch below
# would otherwise hit a NameError on the unbound `batch`.
batch = None
while True:
    try:
        batch = next(loader_iter, None)
        # Identity-check the sentinel: `not batch` would also trip on
        # falsy-but-valid batches and is ambiguous for tensor-like ones.
        if batch is None:
            break
        # collate_fn returns a generator; next(batch) pulls its first field.
        print('batch={}'.format(next(batch)))
    except Exception as e:
        print('batch={} exception={}'.format(batch, e))

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值