一、自定义数据处理函数 (1. Custom data-processing function)
# txt: name of a text file (opened from '/data/' + txt) that lists image file names, one per line
def read_images(txt, batch_size):
imagepaths, labels = list(), list() #分别存放图像路径和label
images_names = []
with open('/data/' + txt, 'r') as r:
images_names.extend(r.readlines())
for name in images_names:
name = name.replace('\n', '')
imagepaths.append(os.path.join(cover_path, name))
labels.append(0)
imagepaths.append(os.path.join(stego_path, name))
labels.append(1)
# Convert to Tensor
imagepaths = tf.convert_to_tensor(imagepaths, dtype=tf.string)
labels = tf.convert_to_tensor(labels, dtype=tf.int32)
# Build a TF Queue, shuffle data
image, label = tf.train.slice_input_producer([imagepaths, labels],
shuffle=True)
# Read images from disk
image = tf.read_file(image)
image = tf.image.decode_jpeg(image, channels=CHANNELS)
# Resize images to a common size
image = tf.image.resize_images(image, [IMG_HEIGHT, IMG_WIDTH])