本文主要介绍数据输入时的俩个函数
数据输入
- 1.数据输入处理
- 1.1数据输入
- 参数与返回
- 1.1.1实现代码
- 1.1.2相关函数
- (1)os.listdir(path) 返回一个包含由 path 指定目录中条目名称组成的列表。
- (2)os.path.join(path,*paths) 智能地拼接一个或多个路径部分。 返回值是 path 和 *paths 的所有成员的拼接
- (3)os.path.isdir(path ) 如果 path 是现有的目录,则返回 True。
- (4)os.path.splitext(path) 将路径名称 path 拆分为 (root, ext) 对使得 root + ext == path,并且扩展名 ext 为空或以句点打头并最多只包含一个句点。
- (5)os.path.join(path,*paths)智能地拼接一个或多个路径部分。 返回值是 path 和 *paths 的所有成员的拼接
- (6)os.rename(src, dst, *, src_dir_fd=None, dst_dir_fd=None)将文件或目录 src 重命名为 dst。
- (7)np.hstack 是 numpy.stack 函数的变体,它通过水平堆叠来生成数组。
- (8)np.array创建一个数组
- (9)numpy.transpose 函数用于对换数组的维度
- (10)np.random.shuffle随机打乱数据
- 1.2批次划分
- 参数与返回:
- 1.2.1实现代码
- 1.2.2相关函数
- (1)tf.cast数据类型转换
- (2)tf.train.slice_input_producer读取数据到队列
- (3)tf.io.read_file读取并输出输入文件名的全部内容。
- (4)tf.image.decode_jpeg解码图像
- (5)tf.image.resize_with_crop_or_pad(image, image_W, image_H)将图片以图片中心进行裁剪或者扩充为 指定的image_W,image_H
- (6)tf.image.per_image_standardization(image)线性缩放image以具有零均值和单位范数.
- (7)tf.train.batch([image, label],batch_size = batch_size, num_threads = 64, capacity = capacity)批量读取数据
- (8) tf.reshape
这里我们继续进行cnn卷积神经网络的应用过程分享。相关的原理可以参考这位博主的文章《理解CNN卷积神经网络原理》
1.数据输入处理
数据输入也可看作是模型的第一层——输入层。这里将主要是对数据进行图像和标签的绑定,并把图像数据进行解码和标准化,将图像数据处理为统一格式,以便后续模型的计算处理。
1.1数据输入
首先介绍一下本次训练对于训练集的格式要求,训练数据按照不同的类别将图片命名为【类别名称+.+随机后缀防止重名+.jpg】
这里是通过python的代码直接批量进行重命名直接处理的。
import os
path = "D:/[code]/python/LIANXI/surgical_dataset_4classes/test/forceps1"
#需要重命名的文件夹路径
filelist = os.listdir(path)
count = 0
for file in filelist:
print(file)
for file in filelist:
Olddir = os.path.join(path, file)
if os.path.isdir(Olddir):
continue
filename = os.path.splitext(file)[0]
filetype = os.path.splitext(file)[1]
Newdir = os.path.join(path, "forceps1."+str(count) + filetype)
os.rename(Olddir, Newdir)
count += 1
一般都是直接给定训练集的路径,我们读取该路径下的所有文件名,将文件名按照"."切开成一个列表,根据命名的结构,该列表中的第一项即为该图片所属的类别。并将该图像的路径和该图像的类编一起放入一个队列,并将其顺序打乱。打乱其顺序的目的是为了提高训练模型的泛性,防止其进入局部最优解。
参数与返回
参数:训练集路径
返回:图像列表image_list,标签列表label_list
1.1.1实现代码
import tensorflow as tf
import os
import numpy as np
def get_files(file_dir):
forceps1 = [] #存放该类别图像路径
label_forceps1 = [] #存放该类别图像标签 0
scissors1 = []
label_scissors1 = [] #存放该类别图像标签 1
scissors2 = []
label_scissors2 = [] #存放该类别图像标签 2
tweezers = []
label_tweezers = [] #存放该类别图像标签 3
for file in os.listdir(file_dir):
name = file.split(sep='.')
print(name)
if 'forceps1' in name[0]:
forceps1.append(file_dir + file)
label_forceps1.append(0)
elif 'scissors1' in name[0]:
scissors1.append(file_dir + file)
label_scissors1.append(1)
elif 'scissors2' in name[0]:
scissors2.append(file_dir + file)
label_scissors2.append(2)
elif 'tweezers' in name[0]:
tweezers.append(file_dir + file)
label_tweezers.append(3)
image_list = np.hstack((forceps1,scissors1,scissors2,tweezers))
label_list = np.hstack((label_forceps1,label_scissors1,label_scissors2,label_tweezers))
temp = np.array([image_list,label_list])
temp = temp.transpose()
np.random.shuffle(temp)
image_list = list(temp[:,0])
label_list = list(temp[:,1])
label_list = [int(float(i)) for i in label_list]
return image_list,label_list
1.1.2相关函数
(1)os.listdir(path) 返回一个包含由 path 指定目录中条目名称组成的列表。
https://docs.python.org/zh-cn/3/library/os.html?highlight=os%20listdir#os.listdir
(2)os.path.join(path,*paths) 智能地拼接一个或多个路径部分。 返回值是 path 和 *paths 的所有成员的拼接
https://docs.python.org/zh-cn/3/library/os.path.html?highlight=os%20path%20join#os.path.join
(3)os.path.isdir(path ) 如果 path 是现有的目录,则返回 True。
https://docs.python.org/zh-cn/3/library/os.path.html?highlight=os%20path%20isdir#os.path.isdir
(4)os.path.splitext(path) 将路径名称 path 拆分为 (root, ext) 对使得 root + ext == path,并且扩展名 ext 为空或以句点打头并最多只包含一个句点。
https://docs.python.org/zh-cn/3/library/os.path.html?highlight=os%20path%20splitext#os.path.splitext
(5)os.path.join(path,*paths)智能地拼接一个或多个路径部分。 返回值是 path 和 *paths 的所有成员的拼接
https://docs.python.org/zh-cn/3/library/os.path.html?highlight=os%20path%20join#os.path.join
(6)os.rename(src, dst, *, src_dir_fd=None, dst_dir_fd=None)将文件或目录 src 重命名为 dst。
https://docs.python.org/zh-cn/3/library/os.html?highlight=os%20rename#os.rename
(7)np.hstack 是 numpy.stack 函数的变体,它通过水平堆叠来生成数组。
https://www.runoob.com/numpy/numpy-array-manipulation.html
(8)np.array创建一个数组
(9)numpy.transpose 函数用于对换数组的维度
https://www.runoob.com/numpy/numpy-array-manipulation.html
(10)np.random.shuffle随机打乱数据
https://blog.csdn.net/weixin_43896259/article/details/106116955
1.2批次划分
按批次读取数据并转换为统一tf可以识别的格式。
参数与返回:
参数:
image:图像集合
label:标签集合
image_W:规定转换后的图像宽
image_H:规定转换后的图像高
batch_size:一个批次的数据个数
capacity:队列中最多容纳图片的个数
返回值:
按照批次返回标准化后的数据集,标签集。
1.2.1实现代码
def get_batch(image,label,image_W,image_H,batch_size,capacity):
# 转换数据为 ts 能识别的格式
image = tf.cast(image,tf.string)
label = tf.cast(label, tf.int32)
# 将image 和 label 放到队列里
input_queue = tf.train.slice_input_producer([image,label])
label = input_queue[1]
# 读取图片的全部信息
image_contents = tf.io.read_file(input_queue[0])
# 把图片解码,channels =3 为彩色图片, r,g ,b 黑白图片为 1 ,也可以理解为图片的厚度
image = tf.image.decode_jpeg(image_contents,channels =3)
# 将图片以图片中心进行裁剪或者扩充为 指定的image_W,image_H
image = tf.image.resize_with_crop_or_pad(image, image_W, image_H)
# 对数据进行标准化,标准化,就是减去它的均值,除以他的方差
image = tf.image.per_image_standardization(image)
# 生成批次 num_threads 有多少个线程根据电脑配置设置 capacity 队列中 最多容纳图片的个数 tf.training.shuffle_batch 打乱顺序,
image_batch, label_batch = tf.train.batch([image, label],batch_size = batch_size, num_threads = 64, capacity = capacity)
# 重新定义下 label_batch 的形状
label_batch = tf.reshape(label_batch , [batch_size])
# 转化图片
image_batch = tf.cast(image_batch,tf.float32)
return image_batch, label_batch
1.2.2相关函数
(1)tf.cast数据类型转换
https://blog.csdn.net/ddy_sweety/article/details/80408000
(2)tf.train.slice_input_producer读取数据到队列
https://blog.csdn.net/kk123k/article/details/86772813
(3)tf.io.read_file读取并输出输入文件名的全部内容。
https://www.w3cschool.cn/tensorflow_python/tf_io_read_file.html
(4)tf.image.decode_jpeg解码图像
TensorFlow 提供操作来解码和编码 JPEG 和 PNG 格式。编码图像由标量字符串 Tensors 表示,解码图像由shape为[height, width, channels]的3-D uint8张量表示。(PNG也支持 uint16)
参考:TensorFlow图像操作
(5)tf.image.resize_with_crop_or_pad(image, image_W, image_H)将图片以图片中心进行裁剪或者扩充为 指定的image_W,image_H
(6)tf.image.per_image_standardization(image)线性缩放image以具有零均值和单位范数.
这个操作计算(x - mean) / adjusted_stddev, 其中:
mean是图像中所有值的平均值
adjusted_stddev= max(stddev, 1.0/sqrt(image.NumElements())).
stddev是image中所有值的标准偏差.处理统一图像时,它被限制为零,以防止除以0.【参数】: image:形状为[height, width, channels]的三维张量.
【返回】: 与image具有相同形状的标准化的图像.
【可能引发的异常】: ValueError:如果’image’的形状与此功能不兼容.
(7)tf.train.batch([image, label],batch_size = batch_size, num_threads = 64, capacity = capacity)批量读取数据
参考:https://blog.csdn.net/sinat_29957455/article/details/83152823
(8) tf.reshape
参考:https://blog.csdn.net/m0_37592397/article/details/78695318