微信公众号:小白图像与视觉
关于技术,请关注 yysilence00。有问题或建议,请在公众号留言。
# 查看当前挂载的数据集目录, 该目录下的变更重启环境后会自动还原
# View dataset directory. This directory will be recovered automatically after resetting environment.
!ls /home/aistudio/data
data23668
# 查看工作区文件, 该目录下的变更将会持久保存. 请及时清理不必要的文件, 避免加载过慢.
# View personal work directory. All changes under this directory will be kept even after reset. Please clean unnecessary files in time to speed up environment loading.
!ls /home/aistudio/work
!cd /home/aistudio/data/data23668 && unzip -qo Dataset.zip
!cd /home/aistudio/data/data23668/Dataset && rm -f */.DS_Store # 删除无关文件
import os
import time
import random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from multiprocessing import cpu_count
from paddle.fluid.dygraph import Pool2D,Conv2D
from paddle.fluid.dygraph import Linear
# Generate the image list files. Each line is "<image path>\t<label>", where
# the label is the name of the class folder that contains the image.
data_path = '/home/aistudio/data/data23668/Dataset'
character_folders = os.listdir(data_path)
print('character_folders:', character_folders)

# Mode 'w' truncates any list left over from an earlier run, so no explicit
# exists()/remove() step is needed, and each file is opened exactly once
# instead of being re-opened in append mode for every class folder.
# NOTE: os.listdir() returns entries in arbitrary order, so which images end
# up in the test list is not guaranteed to be the same on every machine.
with open('./train_data.list', 'w') as f_train, \
        open('./test_data.list', 'w') as f_test:
    for character_folder in character_folders:
        folder_path = os.path.join(data_path, character_folder)
        # Skip '.DS_Store' and any other entry that is not a directory:
        # os.listdir() below would raise on a plain file.
        if character_folder == '.DS_Store' or not os.path.isdir(folder_path):
            continue

        count = 0  # number of images accepted so far in this class folder
        for img in os.listdir(folder_path):
            if img == '.DS_Store':
                continue
            line = os.path.join(folder_path, img) + '\t' + character_folder + '\n'
            if count % 10 == 0:
                # count = 0, 10, 20, ...: one image in ten goes to the test list.
                f_test.write(line)
            else:
                # The remaining nine in ten go to the training list.
                f_train.write(line)
            count += 1
print('列表已生成')
character_folders: ['6', '3', '5', '8', '4', '0', '1', '2', '9', '.DS_Store', '7']
列表已生成
# Define the readers for the training and test sets.
def data_mapper(sample):
    """Load one sample's image and turn it into a scaled float32 array.

    Args:
        sample: An ``(img, label)`` pair. ``img`` is whatever
            ``PIL.Image.open`` accepts (the list files store file paths);
            ``label`` is passed through unchanged.

    Returns:
        A tuple ``(img, label)`` where ``img`` is the image resized to
        100x100, converted to a float32 array, with its last axis moved to
        the front (HWC -> CHW) and every value divided by 255.0.
    """
    path, label = sample
    # Use a context manager so the underlying file handle is closed instead of
    # leaking one open file per sample. resize() reads the pixel data and
    # returns a new in-memory image, so it stays usable after the file closes.
    with Image.open(path) as source:
        # Image.LANCZOS is the same filter the removed Image.ANTIALIAS alias
        # referred to (the alias no longer exists in Pillow >= 10).
        resized = source.resize((100, 100), Image.LANCZOS)
    pixels = np.array(resized).astype('float32')
    # NOTE(review): transpose((2, 0, 1)) needs a 3-D array, so a single-channel
    # image would raise here - confirm every dataset image has a channel axis.
    pixels = pixels.transpose((2, 0, 1))
    pixels = pixels / 255.0  # scale pixel values by 1/255
    return pixels, label
def data_reader(data_list_path):
def reader():
with open(data_list_path, 'r') as f:
lines = f.readlines()
for line in lines:
img,