利用python下载数据_[深度学习] 各种下载深度学习数据集方法(In python)

#-*- coding:utf-8 -*-

__author__ = 'Leo.Z'

#Tensorflow Version:1.14.0

importosimporttensorflow as tffrom PIL importImage

BATCH_SIZE= 128

defread_cifar10(filenames):

label_bytes= 1height= 32width= 32depth= 3image_bytes= height * width *depth

record_bytes= label_bytes +image_bytes#lamda函数体

#def load_transform(x):

## Convert these examples to dense labels and processed images.

#per_record = tf.reshape(tf.decode_raw(x, tf.uint8), [record_bytes])

#return per_record

#tf v1.14.0版本的FixedLengthRecordDataset(filename_list,bin_data_len)

datasets = tf.data.FixedLengthRecordDataset(filenames=filenames, record_bytes=record_bytes)#是否打乱数据

#datasets.shuffle()

#重复几轮epoches

datasets = datasets.shuffle(buffer_size=BATCH_SIZE).repeat(2).batch(BATCH_SIZE)#使用map,也可使用lamda(注意,后面使用迭代器的时候这里转换为uint8没用,后面还得转一次,否则会报错)

#datasets.map(load_transform)

#datasets.map(lamda x : tf.reshape(tf.decode_raw(x, tf.uint8), [record_bytes]))

#创建一起迭代器tf v1.14.0

iter =tf.compat.v1.data.make_one_shot_iterator(datasets)#获取下一条数据(label+image的二进制数据1+32*32*3长度的bytes)

rec =iter.get_next()#这里转uint8才生效,在map中转貌似有问题?

rec =tf.decode_raw(rec, tf.uint8)

label=tf.cast(tf.slice(rec, [0, 0], [BATCH_SIZE, label_bytes]), tf.int32)#从第二个字节开始获取图片二进制数据大小为32*32*3

depth_major =tf.reshape(

tf.slice(rec, [0, label_bytes], [BATCH_SIZE, image_bytes]),

[BATCH_SIZE, depth, height, width])#将维度变换顺序,变为[H,W,C]

image = tf.transpose(depth_major, [0, 2, 3, 1])#返回获取到的label和image组成的元组

return(label, image)defget_data_from_files(data_dir):#filenames一共5个,从data_batch_1.bin到data_batch_5.bin

#读入的都是训练图像

filenames = [os.path.join(data_dir, 'data_batch_%d.bin' %i)for i in range(1, 6)]#判断文件是否存在

for f infilenames:if nottf.io.gfile.exists(f):raise ValueError('Failed to find file:' +f)#获取一张图片数据的数据,格式为(label,image)

data_tuple =read_cifar10(filenames)returndata_tupleif __name__ == "__main__":#获取label和type的对应关系

label_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

name_list= ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

label_map=dict(zip(label_list, name_list))

with tf.compat.v1.Session() as sess:

batch_data= get_data_from_files('cifar10_dir/cifar-10-batches-bin')#在之前的旧版本中,因为使用了filename_queue,所以要使用start_queue_runners进行数据填充

#1.14.0由于没有使用filename_queue所以不需要

#threads = tf.train.start_queue_runners(sess=sess)

sess.run(tf.compat.v1.global_variables_initializer())#创建一个文件夹用于存放图片

if not os.path.exists('cifar10_dir/raw'):

os.mkdir('cifar10_dir/raw')#存放30张,以index-typename.jpg命名,例如1-frog.jpg

for i in range(30):#获取一个batch的数据,BATCH_SIZE

#batch_data中包含一个batch的image和label

batch_data_tuple =sess.run(batch_data)#打印(128, 1)

print(batch_data_tuple[0].shape)#打印(128, 32, 32, 3)

print(batch_data_tuple[1].shape)#每个batch存放第一张图片作为实验

Image.fromarray(batch_data_tuple[1][0]).save("cifar10_dir/raw/{index}-{type}.jpg".format(

index=i, type=label_map[batch_data_tuple[0][0][0]]))

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值