20210715:pytorch DataLoader 自定义 sampler

需求:实现batch内正负1:1采样比例,验证这种采样会不会影响模型的最终精度

探索:搜索良久,发现没有比较直接的实现,需要自己重写一下DataLoader中的sampler

1:确定一下DataLoader的定义

 2:确认一下DataLoader, Sampler, Dataset三者的关系

        链接:https://zhuanlan.zhihu.com/p/76893455

                Sampler提供indicies

                Dataset根据indicies提供data

                DataLoader将上面两个组合起来,提供最终的batch训练数据

3:注意事项:自定义sampler后,shuffle不能指定(默认即可)

实现:可以参考文末的链接中的demo,也可以参考本文中的实战例子

import os
import cv2
import random
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import Sampler
from torchvision import transforms, utils



def img_noise(img_data):
    ''' 
        添加高斯噪声,均值为0,方差为0.001
    '''
    image = np.array(img_data)
    image = np.array(image/255, dtype=float)
    noise = np.random.normal(0, 0.00001 ** 0.5, image.shape)
    out = image + noise
    out = np.clip(out, 0, 1.0)
    out = np.uint8(out*255)
    return out


def img_add_stripe(img_data):
    image = np.array(img_data)
    h,w,c = image.shape
    a = np.random.random()

    if a<= 0.5:
        for i in range(0,h-1,11):
            stripe_data = np.uint8(np.ones([1, w, 3])*150)
            image[i,:,:] = stripe_data
            if (i+2)<=(h-1):
                image[i+2,:,:] = stripe_data
    else:
        for i in range(0,w-1,11):
            stripe_data = np.uint8(np.ones([h, 3])*150)
            image[:,i,:] = stripe_data
            if (i+2)<=(w-1):
                image[:,i+2,:] = stripe_data
    return image


def img_gamma(img, para):
    img1 = np.power(img/255, para) * 255
    img1 = img1.astype(np.uint8)
    return img1


def data_augmentation(img, label):
    b = torch.rand((1,1)).item()
    if b>=0.7:
        img = img_noise(img)
    if b >= 0.2 and b <= 0.5:
        img = img_gamma(img, 1)
    # if label == 0:
    #     c = np.random.random()
    #     if c>=0.9:
    #         img = img_add_stripe(img)
    if label == 0:
        p = torch.randint(1,100, (1,1))
        if p.item()<8:
            img = cv2.cvtColor(img,cv2.COLOR_RGB2GRAY)
            img = cv2.cvtColor(img,cv2.COLOR_GRAY2RGB)
    return img


def default_loader(path):
    im = cv2.imread(path)
    if im is None:
        print("None:", path)
    if im.shape[0] != 112 or im.shape[1] !=112:
        im = cv2.resize(im,(112, 112), interpolation=cv2.INTER_NEAREST)
    im = cv2.cvtColor(im,cv2.COLOR_BGR2RGB)
    # im = Image.fromarray(im.astype(np.uint8))

    if im is None:
        return None
    else:
        return im
   
   
# define Dataset. Assume each line in your .txt file is [name/tab/label], for example:0001.jpg 1
class MyDatasets(Dataset):
    def __init__(self, img_path, txt_path, dataset = '', data_transforms=None, loader = default_loader, data_augmentation=data_augmentation):
        lines = []
        self.img_name_pos = []
        self.img_name_neg = []
        self.img_label_pos = []
        self.img_label_neg = []
        with open(txt_path) as input_file:
            lines = input_file.readlines()
        for line in lines:
            if int(line.strip().split(' ')[-1])==1:
                self.img_name_pos.append(os.path.join(img_path, line.strip().split(' ')[0]))
                self.img_label_pos.append(1)
            else:
                self.img_name_neg.append(os.path.join(img_path, line.strip().split(' ')[0]))
                self.img_label_neg.append(0)
        
        value = len(self.img_label_pos) -len(self.img_label_neg)  
        if value>=0:
           self.img_name_neg += random.sample(self.img_name_neg, value)
           self.img_label_neg += [0]*value
        else:
            self.img_name_pos += random.sample(self.img_name_pos, -value)
            self.img_label_pos += [1]*(-value)
        
        self.img_name = self.img_name_pos + self.img_name_neg
        self.img_label = self.img_label_pos + self.img_label_neg

        self.data_augmentation = data_augmentation
        self.data_transforms = data_transforms
        self.dataset = dataset
        self.loader = loader
 

    def __len__(self):
        return len(self.img_name)

    def __getitem__(self, item):
        img_name = self.img_name[item]
        label = self.img_label[item]
        img = self.loader(img_name)       

        if self.data_augmentation is not None:
            img = data_augmentation(img, label)     
            
        if self.data_transforms is not None:
            try:
                img = Image.fromarray(img.astype(np.uint8))
                img = self.data_transforms(img)
            except:
                print("Cannot transform image: {}".format(img_name))
        return img, label

 
class MySampler(Sampler):
    def __init__(self, dataset):
        halfway_point = int(len(dataset)/2)
        self.pos_indices = list(range(halfway_point))
        self.neg_indices = list(range(halfway_point, len(dataset)))
        
    def __iter__(self):
        random.shuffle(self.pos_indices)
        random.shuffle(self.neg_indices)
        shuffle_list = []

        new_list = []
        for x,y in zip(self.pos_indices, self.neg_indices):
            new_list.append(x)
            new_list.append(y)
            
        print(self.pos_indices)
        print(self.neg_indices)
        print(new_list)
        return iter(new_list)
    
    def __len__(self):
        return len(self.first_half_indices) + len(self.second_half_indices)
         

# load datasets
def load_mydata(img_path_default, txt_path_default):
    trans_list = [transforms.ColorJitter(brightness=0.5), transforms.ColorJitter(contrast=0.5),
                  transforms.ColorJitter(saturation=0.5), 
                #   transforms.ColorJitter(saturation=0.5), transforms.ColorJitter(hue=0.5),
                  transforms.RandomRotation(5, resample=Image.BILINEAR, expand=False, center=(56, 56))]
    transform = transforms.RandomChoice(trans_list)
    transform = transforms.RandomApply([transform], p=0.2)
    
    data_transforms = { 'train':
                            transforms.Compose([
                                            transform,
                                            transforms.RandomCrop(96),
                                            transforms.RandomHorizontalFlip(),
                                            transforms.ToTensor(),
                                            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
                            ]),
                        'test':
                            transforms.Compose([
                                            transforms.CenterCrop(96),
                                            transforms.ToTensor(),
                                            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
                            ])
                        }

    image_datasets = {x: MyDatasets(img_path=img_path_default,
                                    txt_path=(txt_path_default + '/' + x + '.txt'),
                                    data_transforms=data_transforms[x],
                                    dataset=x) for x in ['train', 'test']}


    
    our_sampler = {x:MySampler(image_datasets[x]) for x in ['train', 'test']} 
    dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                 sampler=our_sampler[x],
                                                 batch_size=2,
                                                 num_workers = 2,
                                                 ) for x in ['train', 'test']}

    return dataloders['train'], dataloders['test']


list_train = r'/home/mntsde/lilai/imgs_ir'
list_test = r'/home/mntsde/lilai/imgs_ir'
train_loader, test_loader = load_mydata(list_train, list_test)
for epoch in range( 0, 10):
    print("********************************")
    print("epoch: ", epoch)
    print("********************************")
    for i, data in enumerate(train_loader):
        print(i, data[1])
    print("--------------------------------")    
    for i, data in enumerate(test_loader):
        print(i, data[1])

参考:

(1)https://www.scottcondron.com/jupyter/visualisation/audio/2020/12/02/dataloaders-samplers-collate.html#SequentialSampler 

(2)https://blog.csdn.net/u010087338/article/details/117927204

(3)https://github.com/ufoym/imbalanced-dataset-sampler

  • 1
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
### 回答1: 在使用LSTM(长短期记忆网络)进行文本序列处理时,遇到数据不等长的问题是比较常见的情况。PyTorchDataLoader可以帮助我们有效地处理这种情况。 首先,我们需要将数据进行预处理,并将其转换为适应LSTM模型输入的格式。通常,我们会使用分词技术将文本分割为单词或子词,并为每个单词或子词分配一个唯一的索引。然后,我们可以将每个序列中的单词索引转换为张量,并使用Packing技术将它们打包为一个批次。 其次,要在PyTorch中处理不等长的序列,可以使用Collate函数来自定义一个处理数据的函数。Collate函数以批次数据作为输入,并在其中对数据进行处理。例如,在Collate函数中,我们可以使用torch.nn.utils.rnn.pad_sequence函数对序列进行填充,使它们的长度相等。 然后,我们需要指定一个Sampler来确定每个批次的数据样本。在处理不等长序列时,通常建议使用Sampler来根据数据长度对样本进行排序,以使每个批次的数据样本长度相对接近。 最后,在创建DataLoader对象时,我们可以通过设置参数drop_last=True来避免最后一个批次中的样本长度小于其他批次。这样做可以确保每个批次的数据样本长度一致,并且减少处理不等长序列的复杂性。 综上所述,使用PyTorchDataLoader和一些预处理技术,我们可以有效地处理数据不等长的情况,并将其用于训练和评估LSTM等序列模型。 ### 回答2: 在使用PyTorch中的数据加载器(DataLoader)时,如果我们处理的是不等长的数据序列并使用LSTM模型,我们需要考虑如何处理这种情况。 首先,我们需要确保我们的数据已经预处理为适当的格式。对于不等长的数据序列,我们需要将它们填充或裁剪为相同的长度。一种常见的方法是使用填充(padding)来将所有序列扩展到最长序列的长度。我们可以使用PyTorch的`pad_sequence`函数来实现这一步骤。对于较短的序列,我们可以使用特定的填充值,如0,进行填充。 接下来,我们需要创建一个自定义的数据集类来处理我们的数据。这个类应该提供`__getitem__`和`__len__`方法。在`__getitem__`方法中,我们需要根据索引获取填充后的序列,并返回它们以及对应的标签。我们还可以使用`collate_fn`函数来对获取的批次数据进行进一步处理,以适应LSTM模型的输入要求。 然后,我们可以使用PyTorch的`DataLoader`来加载我们的数据集。在初始化`DataLoader`时,我们需要设置`collate_fn`参数为我们自定义的处理函数,以确保加载器能够正确处理不等长的数据序列。此外,我们还应该选择适当的`batch_size`、`shuffle`和`num_workers`等参数。 最后,在训练模型时,我们需要在LSTM模型的`forward`方法中处理不等长的数据序列。这可以通过在LSTM模型的输入中指定序列的长度或使用动态计算图的方法来实现。 总之,当我们有不等长的数据序列并使用LSTM模型时,我们需要对数据进行适当的预处理,创建自定义的数据集类来处理数据,使用`DataLoader`加载器以及在模型中适当地处理不等长的数据序列。通过这些步骤,我们可以成功处理不等长的数据序列并应用于LSTM模型的训练。 ### 回答3: 在使用PyTorchDataloader加载数据时,遇到数据不等长的情况,并且需要将这些数据传入LSTM模型进行训练。这个问题可以有几种解决方案。 第一种方案是使用PyTorch提供的pad_sequence函数将数据进行填充,使其等长。pad_sequence函数会找到所有数据中最长的序列,然后在其他序列末尾填充0,使它们的长度与最长序列相等。这样处理后的数据可以作为模型的输入进行训练。需要注意的是,LSTM模型需要将数据按照序列长度进行排序,以便在训练过程中使用pack_padded_sequence函数进行处理。 第二种方案是使用torch.nn.utils.rnn.pack_sequence函数将数据打包成一个批次。该函数会将每个序列长度存储下来,并按照序列长度降序排列,再将序列内容打包成一个Tensor。在训练过程中,可以使用pack_padded_sequence函数对打包后的数据进行处理,提高模型的训练效率。 第三种方案是对数据进行随机舍弃或截断,使得所有序列等长。这种方法可能会导致数据丢失一部分信息,但在一定程度上可以减少数据处理的复杂性。 以上是针对数据不等长的情况,在使用PyTorchDataloader加载数据时可以采取的几种方案。根据具体的需求和应用场景,选择合适的方法来处理数据不等长的情况,以提高模型的效果和训练速度。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

微风❤水墨

你的鼓励是我最大的动力!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值