深度学习笔记(十六)---几种数据形式的灵活读取

数据的读取是我们进行工作的第一步,在我们拿到各种各样的数据时,首先要知道数据的格式以及label,对应的种类,数量,下面就先介绍常用数据的读取方式。从数据角度分两种,一是ndarray格式的纯数值数据的读写,二是对象(数据结构)如dict的文件存取。

导入将要使用的函数包

import cv2
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt
from torchvision import transforms as tfs 
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

使用ImageFolder读取图片,下面尽量在程序注释

#三个文件夹,每个文件夹一共有三个图片最为例子
folder_set = ImageFolder('./chapter8_pytorch_advances\example_data\image')

#查看名称和类别的下标的对应
print(folder_set.class_to_idx)
#得到所有图片的名字和标签
print(folder_set.imgs)

im,label = folder_set[0]
#输出标签已经显示图片
#print('\n',label)
#plt.imshow(im)
#plt.show()

#传入数据预处理方式
data_tf = tfs.ToTensor()
folder_set = ImageFolder('./chapter8_pytorch_advances/example_data/image',transform=data_tf)
im ,label = folder_set[0]
print(im)
tensor([[[0.2314, 0.1686, 0.1961,  ..., 0.6196, 0.5961, 0.5804],
         [0.0627, 0.0000, 0.0706,  ..., 0.4824, 0.4667, 0.4784],
         [0.0980, 0.0627, 0.1922,  ..., 0.4627, 0.4706, 0.4275],
         ...,
         [0.8157, 0.7882, 0.7765,  ..., 0.6275, 0.2196, 0.2078],
         [0.7059, 0.6784, 0.7294,  ..., 0.7216, 0.3804, 0.3255],
         [0.6941, 0.6588, 0.7020,  ..., 0.8471, 0.5922, 0.4824]],

        [[0.2431, 0.1804, 0.1882,  ..., 0.5176, 0.4902, 0.4863],
         [0.0784, 0.0000, 0.0314,  ..., 0.3451, 0.3255, 0.3412],
         [0.0941, 0.0275, 0.1059,  ..., 0.3294, 0.3294, 0.2863],
         ...,
         [0.6667, 0.6000, 0.6314,  ..., 0.5216, 0.1216, 0.1333],
         [0.5451, 0.4824, 0.5647,  ..., 0.5804, 0.2431, 0.2078],
         [0.5647, 0.5059, 0.5569,  ..., 0.7216, 0.4627, 0.3608]],

        [[0.2471, 0.1765, 0.1686,  ..., 0.4235, 0.4000, 0.4039],
         [0.0784, 0.0000, 0.0000,  ..., 0.2157, 0.1961, 0.2235],
         [0.0824, 0.0000, 0.0314,  ..., 0.1961, 0.1961, 0.1647],
         ...,
         [0.3765, 0.1333, 0.1020,  ..., 0.2745, 0.0275, 0.0784],
         [0.3765, 0.1647, 0.1176,  ..., 0.3686, 0.1333, 0.1333],
         [0.4549, 0.3686, 0.3412,  ..., 0.5490, 0.3294, 0.2824]]])

 通过这种方式能够非常方便的访问每个数据点 

下面将读取TXT文件:

#定义一个子类叫custom_dataset,继承与dataset
class custom_dataset(Dataset):
    def __init__(self,txt_path,transform=None):
        self.transform = transform #传入数据预处理
        with open(txt_path,'r') as f:
            lines = f.readlines()

        self.img_list = [i.split()[0] for i in lines]  #得到所有的图像的名字
        self.label_list = [i.split()[1] for i in lines]  #得到所有的label

    def __getitem__(self,idx):  #根据idx取出其中一个
        img = self.img_list[idx]
        label = self.label_list[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img,label

    def __len__(self):  #总数据多少
        return len(self.label_list)

txt_dataset = custom_dataset('./chapter8_pytorch_advances/example_data/train.txt') # 读入 txt 文件

#取得其中一个数据
data,label = txt_dataset[0]
print(data,'\n',label)

# 再取一个
data2, label2 = txt_dataset[34]
print(data2)
print(label2)

通过这种方式我们也能够非常方便的定义一个数据读入,同时也能够方便的定义数据预处理

#用自定义的数据读入举例子
train_data1 = DataLoader(folder_set,batch_size=2,shuffle=True)  #将2个数据作为一个batch
for im ,label in train_data1:#访问迭代器
    print(label)
tensor([1, 0])
tensor([0, 2])
tensor([0, 2])
tensor([1, 1])
tensor([2])

通过训练我们可以访问到所有的数据,这些数据被分为了 5 个 batch,前面 4 个都有两个数据,最后一个 batch 只有一个数据,因为一共有 9 个数据,同时顺序也被打乱了

下面我们用自定义的数据读入举例子:

train_data2 = DataLoader(txt_dataset,8,True)  #batch size 设置为8

im,label = next(iter(train_data2))  #使用这种方式访问迭代器中第一个batch的数据
#结果
('834_1.png',
 '957_5.png',
 '977_19.png',
 '151_7.png',
 '455_4.png',
 '997_7.png',
 '5097_1.png',
 '622_1.png')

现在有一个需求,希望能够将上面一个 batch 输出的 label 补成相同的长度,短的 label 用 0 填充,我们就需要使用 collate_fn 来自定义我们 batch 的处理方式,下面直接举例子:

def collate_fn (batch):
    batch.sort(key=lambda x: len(x[1]),reverse=True)#将数据集按照label的长度从大到小排序
    img,label = zip(* batch)#将数据和label进行匹配取出
    #填充
    pad_label = []
    lens = []
    max_len = len(label[0])
    for i in range(len(label)):
        temp_label = label[i]
        temp_label += '0' * (max_len - len(label[i]))
        pad_label.append(temp_label)
        lens.append(temp_label)
        lens.append(len(label[i]))
    pad_label
    return img,pad_label,lens  #输出label的真是长度

train_data3 = DataLoader(txt_dataset, 8, True, collate_fn=collate_fn) # batch size 设置为 8
im,label,lens = next(iter(train_data3))

print(im,'\n',label,'\n',lens)

下面对其他数据类型的读取方式进行补充:

数值数据的读写

  •  .bin格式,np.tofile() 和 np.fromfile()
import numpy as np
a = np.random.randint(0, 100, size=(10000, 5000))
print(a.dtype, a.shape)  # int32 (10000, 5000) 下同

a.tofile('data/a.bin')
b = np.fromfile('data/a.bin', dtype=np.int32)  # 需要设置正确dtype
print(b.shape)           # (50000000,) 读入数据是一维的,需要reshape

 

  • .npy格式,np.save() 和 np.load() 

numpy专用的二进制格式保存数据,能自动处理变量type和size

np.save('data/a.npy', a)
b = np.load('data/a.npy')
print(b.shape)           # (10000, 5000)

 

  • .txt格式 (或.csv/.xlsx),np.savetxt() 和 np.loadtxt() 

csv或xlsx可以借助python其他库工具实现。csv只能存数字,xlsx可以存数字字符。

np.savetxt('data/a.txt', a, fmt='%d', delimiter=',') #设置以整数存储,以逗号隔开
b = np.loadtxt('data/a.txt', delimiter=',')
print(b.shape)           # (10000, 5000)

 

  • .h5格式,h5py.File(data, ‘r’ or ‘w’)
import h5py
f = h5py.File('data/a.h5','w')   #创建一个h5文件,文件指针是f  
f['data'] = a                    #将数据写成data的键值
f.close()                        #关闭文件

f = h5py.File('data/a.h5','r')   #打开h5文件  
# print f.keys()                 #查看所有的键
b = f['data'][:]                 #取出键名为data的键值
f.close()
print(b.shape)           # (10000, 5000)

python上数据存储,推荐h5

几种方法存储读取耗时:

  • 存成bin还要处理类型和矩阵大小,繁琐
  • 存成txt或csv节省空间,但速度太慢
  • 使用np.save()和h5py的方法比较方便快捷,特别是h5py还能存字典对象,用处很大,所以被很多人推荐
  • json存储数据必须是可序列化对象,而pickle虽然能存数值矩阵a,但一般不这么用

参考

对象(数据结构)的存取,关于序列化与反序列化

序列化,把对象(数据结构)序列化成字符串,可以存储在文件中,也就是对象的持久化 
反序列化,序列化的反向操作,把经过序列化的对象(数据结构)加载到内存

python中有个两个序列化模块:json 和 pickle 
也正是python两个相对轻量级的数据持久化方式。

import json

raw_dict = {'key1': 'value1', 'key2': 'value2'} #
wf = open('save', 'w')  # 将dict类型对象序列化存储到文件中           
json.dump(obj=raw_dict, fp=wf)

rf = open('save')      # 将文件中的数据反序列化成内置的dict类型
raw_data = json.load(fp=rf)
print(raw_data)  # 输出{'key1': 'value1', 'key2': 'value2'}
import pickle
#python2有个cpickle模块是用c实现的pickle,速度较快,在python3中已改为了pickle

raw_dict = {'key1': 'value1', 'key2': 'value2'}
wf = open('save.pkl', 'wb')      # 将dict类型对象序列化存储到文件中
pickle.dump(obj=raw_dict, file=wf)

rf = open('save.pkl', 'rb')
raw_data = pickle.load(file=rf)  # 将文件中的数据反序列化成内置的dict类型
print(raw_data)  # 输出{'key1': 'value1', 'key2': 'value2'}

序列化到 ‘save’文件中的对象是这样的: 
{"key1": "value1", "key2": "value2"} 
而序列化到 ‘save.pkl’文件中是这样的: 
8003 7d71 0028 5804 0000 006b 6579 3171 
0158 0600 0000 7661 6c75 6531 7102 5804 
0000 006b 6579 3271 0358 0600 0000 7661 
6c75 6532 7104 752e

存储了几个对象就只能load几次,如果load超过了存储的对象,会抛出EOFError异常。 
这里能够存储的对象可以是任意对象,字典、列表、元组、numpy数组等。

二者的区别:

  1. JSON是文本形式的存储,Pickle则是二进制形式(至少常用二进制)
  2. JSON是人可读的,Pickle不可读
  3. JSON广泛应用于除Python外的其他领域,Pickle是Python独有的
  4. JSON只能dump一些python的内置对象,Pickle可以存储几乎所有对象
  5. 如果偏向应用特别是web应用,可以常用JSON格式,如果偏向算法,尤其是机器学习,则通常使用cPickle,pylearn2库中保存model就是使用这项技术的

但是因为pkl的存取速度比json还要慢,有时存储较大的训练数据库(几G)这种情况,不存对象直接把内容存成明文(一般的文本文件)也要比存pkl好。

参考:Pickle vs JSON — Which is Faster?

补充: 
pandas官网也给python提供了较全的读取文件不同格式的方法

读取存储csv格式文件: 
pd.read_csv() df.to_csv() 
读取存储json格式文件: 
pd.read_json() df.to_json() 
读网页中的表格: 
pd.read_html() df.to_html() 
读取xls文件,有excel版本限制;存储时数据要先存成DataFrame: 
pd.read_excel() pd.to_excel() 
读取存储pickle格式文件: 
pd.read_pickle('foo.pkl') df.to_pickle('foo.pkl') 
读取存储为HDFS文件: 
pd.HDFStore("store.h5") 
df.to_hdf() 
pd.read_hdf()
 

 

 

 

 

 

 

 

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值