一、数据预处理和格式转换
1、在kaggle上下载了一个数据踩了一个巨坑,图像有三通道和四通道的图像,导致后续tfrecord读取造成唯独不匹配,用opencv写了一个脚本转换图像
#-*- encoding:utf-8 -*-
#将4通道图像转换成3通道
import os
import cv2
save_dir = '../train_convert'
origin_dir = '../train'
i = 0
for folder in os.listdir(origin_dir):
os.makedirs(save_dir + '/' + folder)
for img_file in os.listdir(os.path.join(origin_dir,folder)):
image = cv2.imread(os.path.join(origin_dir,folder,img_file))
#if image.shape[2] == 4:
# print(os.path.join(origin_dir,folder,img_file))
cv2.imwrite(os.path.join(save_dir,folder,img_file),image)
i+=1
print('processing ' + str(i) + 'th class')
2、生成tfrecord文件,文件格式和代码如下:
#-*- encoding:utf-8 -*-
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
import config
import os
os.system('rm -rf {}'.format('../dataset'))
os.makedirs('../dataset')
#生成文件索引及其对应类别编号函数
def get_files(file_dir):
labels=[]
images=[]
with open(config.labels_map_dir,'w+') as f:
for i,folder in enumerate(os.listdir(file_dir)):
f.write(folder+' '+str(i)+'\n')#类别名称及其对应的数字
for file in os.listdir(file_dir+'/'+folder):
images.append(file_dir+'/'+folder+'/'+file)
labels.append(i)
temp = np.array([images,labels])
temp = temp.transpose()
np.random.shuffle(temp)#打乱文件的顺序
image_list = list(temp[:,0])
label_list = list(temp[:,1])
label_list = [int(i) for i in label_list]
return image_list,label_list
#************************************************************
#********************