在训练神经网络的过程中,常常需要fine tune一个现有的网络,首先是需要对输入数据进行预处理,包括有:
- 对尺寸大小进行处理
- 将正负例和测试的data&label保存为h5文件
- 将h5文件中data&label对应的书序打乱
实现代码如下:
1. 导包以及VGG网络初始化
import numpy as np
import matplotlib.pyplot as plt
import skimage
import skimage.io
import skimage.transform
import os
import h5py
%matplotlib inline
plt.rcParams['figure.figsize']=(10,10)
plt.rcParams['image.interpolation']='nearest'
plt.rcParams['image.cmap']='gray'
VGG_MEAN = [103.939, 116.779, 123.68]
2.处理图片RGB三通道
def preprocess(img):
out = np.copy(img) * 255
out = out[:, :, [2,1,0]] # swap channel from RGB to BGR
# sub mean
out[:,:,0] -= VGG_MEAN[0]
out[:,:,1] -= VGG_MEAN[1]
out[:,:,2] -= VGG_MEAN[2]
out = out.transpose((2,0,1)) # h, w, c -> c, h, w
return out
3.像素归一化
def deprocess(img):
out = np.copy(img)
out = out.transpose((1,2,0)) # c, h, w -> h, w, c
out[:,:,0] += VGG_MEAN[0]
out[:,:,1] += VGG_MEAN[1]
out[:,:,2] += VGG_MEAN[2]
out = out[:, :, [2,1,0]]
out /= 255
return out
4.尺寸处理
# returns image of shape [224, 224, 3]
# [height, width, depth]
def load_image(path):
# load image
img = skimage.io.imread(path)
img = img / 255.0
assert (0 <= img).all() and (img <= 1.0).all()
#print "Original Image Shape: ", img.shape
# we crop image from center
short_edge = min(img.shape[:2])
yy = int((img.shape[0] - short_edge) / 2)
xx = int((img.shape[1] - short_edge) / 2)
crop_img = img[yy : yy + short_edge, xx : xx + short_edge]
# resize to 224, 224
resized_img = skimage.transform.resize(crop_img, (224, 224))
return resized_img
5.循环遍历文件保存数据以及label
关键代码:
保存count值方便后续使用以及检查:
imgData_count = 0
imgTest_count = 0
FilePrefixlist = [] #存取文件前缀名的列表
#分别将list中出现的image名字和label保存在不同的矩阵
PositiveList = np.loadtxt(r'plane.txt',dtype=np.int)
#获取文件的前缀名,前缀名为string类型
with open(r'plane.txt', 'r') as f:
while True:
line = f.readline() #逐行读取
if not line:
break
linesplit = line.split(' ')
FilePrefixlist.append(linesplit[0]) #只取得第一列的数据即文件的前缀名
labelPositiveList = PositiveList[:,1]
#统计正例中保存为训练集的个数
labelPositiveCount=np.sum(labelPositiveList==1)
labelNegativeCount=np.sum(labelPositiveList==-1)
#初始化训练集和测试集的data和label
imgData = np.zeros([labelPositiveCount+190,3,224,224],dtype= np.float32)
label = []
imgTest = np.zeros([labelNegativeCount+95,3,224,224],dtype= np.float32)
labelTest =[]
接下里开始正式读数据和label,以其中某一个文件数据为例:
#通过读正类脚本文件将正类中train和test的保存到对应data中
for index in range(len(FilePrefixlist)):
line=FilePrefixlist[index]
#如果label=1,那么是训练集
if labelPositiveList[index]==1 :
imgData[imgData_count,:,:,:]=preprocess(load_image(path+'/'+line+'.jpg'))
label.append(1)
imgData_count = imgData_count+1
#否则label就是-1,代表这是一个测试集的数据,放在测试集中
else:
imgTest[imgTest_count,:,:,:]=preprocess(load_image(path+'/'+line+'.jpg'))
labelTest.append(1)
imgTest_count = imgTest_count+1
上述过程将所有data存在numpy数组里面,label存在list中用append()方式追加,于是需要将list转变为numpy数组:
#将label列表变为numpy
label = np.array(label)
labelTest = np.array(labelTest)
使用shuffle打乱顺序:
#打乱h5文件训练集正负例顺序
index = [i for i in range(len(imgData))]
np.random.shuffle(index)
imgData = imgData[index]
label = label[index]
创建h5文件,放入data和label:
f = h5py.File('aeroplane_train.h5','w')#相对路径,绝对路径会报错
f['data']=imgData
f['label']=label
f.close()
#HDF5的读取:
f = h5py.File('aeroplane_train.h5','r') #打开h5文件
f.keys() #可以查看所有的主键
a = f['data'][:] #取出主键为data的所有的键值
f.close()
数据预处理以及保存过程关键代码如上所示。
在编码中遇到一些小坑:
1、win与linux在写路径是正反斜杠”/”“\”的问题,win下复制的路径与自己添加的完整路径的斜杠方向不同。。。
2、在loadtxt的时候,由于\t或者\n会识别为转义字符,于是需要在路径前加上r,否则会报错,例如:
PositiveList = np.loadtxt(r'C:\Users\Administrator\plane.txt',dtype=np.int)