MindSpore下载、解压网上数据集资源
开启深度第一步:获取数据集
(此处是下载并解压文件,前提是知道文件名与下载路径)
import os
import zipfile
import requests
from matplotlib import pyplot as plt
# 主要包括class10数据集的载入与处理,也可以自定义数据集。
import mindspore.dataset as ds
def data_download():
# 得到目前的文件路径,设置请求头下载文件
this_path = os.getcwd()
filename = 'cifar10_mindspore.zip'
url = 'https://professional-construction.obs.cn-north-4.myhuaweicloud.com/ComputerVision/cifar10_mindspore.zip'
# 使用request下载
print('--------正在使用requests下载---------')
r = requests.get(url)
with open(filename, 'wb') as code:
code.write(r.content)
print('-------------下载完成---------------')
# 将打包的文件解压
print('--------正在使用zipfile解压----------')
with zipfile.ZipFile(filename, 'r') as f:
for file in f.namelist():
f.extract(file, this_path)
# 删除下载的压缩包文件
os.remove(os.path.join(this_path, filename))
print('------------文件已获取完成------------')
if __name__ == '__main__':
# 创建图像标签列表
category_dict = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer',
5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
current_path = os.getcwd()
data_path = os.path.join(current_path, 'data\\10-verify-bin')
ten_class_ds = ds.Cifar10Dataset(data_path)
# 下载文件
# data_download()
# 设置图像大小
plt.figure(figsize=(8, 8))
# 打印9张子图,这里每次提取出来的数据是随机的
i = 1
# dic是数据集ten_class_ds里的一个元素,它有两个属性--->image对应图像、label对应标签0,1,2....,9
for dic in ten_class_ds.create_dict_iterator(): # iterator是迭代器,简单理解为 将元素随机排成一个序列,确保每个元素只取1次
# 展示的图按3X3排布
plt.subplot(3, 3, i)
# 显示图片,im show要输入数组格式
plt.imshow(dic['image'].asnumpy())
plt.xticks([])
plt.yticks([])
plt.axis('off')
# 显示图片,自定义的字典category_dict 要输入数字格式,以下查获的label 先转化为数组再通过sum转化为数字
plt.title(category_dict[dic['label'].asnumpy().sum()])
i += 1
if i > 9:
break
plt.show()