pytorch自定义Dataset

因为需要读取大量数据到神经网络里进行训练,之前一直使用的keras.fit不管用了,后来发现pytorch自带的Dataset和Dataloader能很好的解决这个问题。如果使用tensorflow的话,需要使用tf.data.Dataset.from_tensor_slices().map()方法或者使用队列来解决这个问题,
tensorflow自定义Dataset教程链接:
http://www.51zixue.net/TensorFlow/765.html

在网上找了一些教程,只写了一些基础的代码,没有讲清楚为啥这么写,有些bug也没有提示。
这里写一下我自己的理解:
首先自定义Dataset必须要写一个继承from torch.utils.data import Dataset的类,其中除了init方法以外还有两个方法,__getitem__()和__len__(),可以这么理解:在使用pytorch自带的Dataloader把Dataset包裹起来调用的时候,会认为这个Dataset一共有的数据量就是__len__()的返回值,比如Dataloader的batch参数为8,即一次读取8个数据,它就会产生8个不同的数值,把这些数值作为__getitem__()的参数输入进去调用,然后把返回的每次返回的数据,共8个,打包好来给用户。

其中,get_item()的返回值也没要必须是(一个数据+一个label)的形式,只要有返回值就可以,只不过相对应的,在遍历Dataloader,其实也就是在遍历这些返回值,只要做好相应处理即可

其中我遇到了两个报错是和这部分有关的

ValueError: num_samples should be a positive integer value, but got num_samples=0

这个原因比较简单,就是 __len__(self)返回值是0,导致程序认为不存在样本数量,关注修改这部分即可

第二个:
UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program.

这个报错的原因比较复杂,主要原因就是报错里说的,The given NumPy array is not writeable。我在本地测试正常,但是把程序部署到gpu算力平台上时出现了这个问题,解决方法是在

__getitem__(self, index)

这个函数的返回值里,把原来返回的feature用np.array()包裹,注意feature原本就是numpy数组,这里再调用一次np.array是为了达到copy的效果,从而解决这个问题。

下面附上整段的代码

# 准备pytorch的数据
from torch.utils.data import Dataset, DataLoader
from OSutils import get_data_path, load_jsondata
from ByteSequencesFeature import byte_sequences_feature
from torch.utils.data import DataLoader
import numpy as np
import torch


def data_loader_multilabel(file_path='', label_dict={}):
    # 用于读取多标签的情况
    file_md5 = file_path.split('/')[-1]
    return byte_sequences_feature(file_path), label_dict.get(file_md5)


def data_loader(file_path='', label_dict={}):
    # 用于读取单标签的情况
    file_md5 = file_path.split('/')[-1]
    if file_md5 in label_dict:
        return byte_sequences_feature(file_path), 1
    else:
        return byte_sequences_feature(file_path), 0


class MalconvDataSet(Dataset):

    def __init__(self, black_samples_dir="black_samples/", white_samples_dir='white_samples/',
                 label_dict_path='label_dict.json', label_type="single", valid=False, valid_size=0.2, seed=207):

        self.file_list = get_data_path(black_samples_dir)
        self.loader = data_loader_multilabel

        if label_type == "single":
            self.loader = data_loader
            self.file_list += get_data_path(white_samples_dir)

        if label_type == "predict":
            self.label_dict = {}
            self.loader = data_loader
        else:
            self.label_dict = load_jsondata(label_dict_path)
            np.random.seed(seed)
            np.random.shuffle(self.file_list)

        # 如果是需要测试集,就在原来的基础上分割
        # 因为设定了随机种子,所以分割的结果是一样的


        valid_cut = int((1 - valid_size) * len(self.file_list))
        if valid:
            self.file_list = self.file_list[valid_cut:]
        else:
            self.file_list = self.file_list[:valid_cut]

    def __getitem__(self, index):
        file_path = self.file_list[index]
        feature, label = self.loader(file_path, self.label_dict)
        return np.array(feature), label

    def __len__(self):
        return len(self.file_list)

调用的代码:
使用数据集,这里划分测试集的功能我在DataSet里已经定义好了,所以只需要更改valid参数即可

from torch.utils.data import DataLoader
train_data_loader = DataLoader(
    MalconvDataSet(black_samples_dir=black_samples_dir, white_samples_dir=white_samples_dir,
                   label_dict_path=label_dict_path, label_type=task_type, valid=False,
                   valid_size=valid_size, seed=207), batch_size=8, shuffle=True, )
test_data_loader = DataLoader(
    MalconvDataSet(black_samples_dir=black_samples_dir, white_samples_dir=white_samples_dir,
                   label_dict_path=label_dict_path, label_type=task_type, valid=True,
                   valid_size=valid_size, seed=207), batch_size=8, shuffle=True, )
 for step, batch_data in enumerate(train_data_loader):
        exe_input = batch_data[0].cuda() if use_gpu else batch_data[0]
        exe_input = exe_input.long()

        label = batch_data[1].cuda() if use_gpu else batch_data[1]
        label = label.long()
        pred = malconv(exe_input)
        loss = model_loss(pred, label)
        adam_optim.zero_grad()
        loss.backward()
        adam_optim.step()

        with torch.no_grad():
            pred = F.softmax(pred, dim=-1)
            pred = pred.argmax(1)
            current_acc = (pred == label).float().mean().item()
            current_loss = loss.item()
            history['tr_loss'].append(current_loss)
            history['tr_acc'].append(current_acc)
            if step % display_step == 0:
                print(step_msg.format(step, np.mean(history['tr_loss']), np.mean(history['tr_acc'])), flush=True)
                print("current_step_acc:{:.4f}".format(current_acc), flush=True)
            if step % test_step == 0:
                # 每过一段时间就进行测试和保存记录
                break
  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值