数据集可在 Kaggle 官网获取（不知为何，CSDN 资源无法上传）。
首先是制作数据集，对应的文件为：
input_data.py
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
'''
从原始数据集中获取图片和标签
对所得图像进行预处理
'''
class readData(object):
def __init__(self, io='../data/train/'):
"""
初始化
:param io: 数据集路径
"""
self.val_dataset = None
self.train_dataset = None
self.io = io
self.image_list = None # 图像路径
self.label_list = None # 对应标签
self.label_names = ['cat', 'dog']
def get_files(self):
# return: 乱序后的图片和标签
cats = []
label_cats = []
dogs = []
label_dogs = []
file_dir = self.io
# 载入数据路径并写入标签值
for file in os.listdir(file_dir):
name = file.split(sep='.')
if name[0] == 'cat':
cats.append(file_dir + file)
label_cats.append(0) # 猫猫置0
else:
dogs.append(file_dir + file)
label_dogs.append(1) # 狗狗置1
print("There are %d cats\nThere are %d dogs" % (len(cats), len(dogs)))
# 合并猫狗图片路径和标签
image_list = np.hstack((cats, dogs))
label_list = np.hstack((label_cats, label_dogs))
# 将图片和标签合并为矩阵
temp = np.array([image_list, label_list])
temp = temp.transpose() # 转置
# 乱序
np.random.shuffle(temp)
# 返回乱序后图片路径和标签
image_list = list(temp[:, 0])
label_list = list(temp[:, 1])
label_list = np.asarray(label_list).astype(np.int) # 整数标签
self.image_list = image_list
self.label_list = label_list
label_list = np.eye(2)[label_list] # 转为one-hot矩阵
return image_list, label_list
def load_and_preprocess_image(self, path):
image = tf.io.read_file(path)
image = tf.image.decode_jp