**
<一>准备工作!!!
**
需要下载的东西如下:
数据集:17flowers数据集
预训练权重模型:vgg16.npy 提取码:66im
我的做法是从17flowers数据集里面的17个分类的文件夹中,依次分别cut出5张图片作为测试集。
<二>训练具体流程及细节说明
1.首先我们要将17flowers数据集转换成TFRecord的形式(tfrecord.py)
这是因为TFRecord格式是为Tensorflow打造的一种非常高效的数据读取方式,在了解TFRcord 的过程中,楼主看到了超多优秀的资料!比如:你可能无法回避的 TFRecord 文件格式详细讲解
这里给出
// An highlighted block
import os
import tensorflow as tf
from PIL import Image
def creat_tf(imgpath):
cwd = os.getcwd()
classes = os.listdir(cwd + imgpath)
# 定义tfrecords文件存放
writer = tf.python_io.TFRecordWriter("train.tfrecords")
for index, name in enumerate(classes):
class_path = cwd + imgpath + name + "/"
print(class_path)
if os.path.isdir(class_path):
for img_name in os.listdir(class_path):
img_path = class_path + img_name
img = Image.open(img_path)
img = img.resize((224, 224))
img_raw = img.tobytes()
example = tf.train.Example(features=tf.train.Features(feature={
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(name)])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString())
print(img_name)
writer.close()
def read_example():
# 读取
for serialized_example in tf.python_io.tf_record_iterator("train.tfrecords"):
example = tf.train.Example()
example.ParseFromString(serialized_example)
# image = example.features.feature['img_raw'].bytes_list.value
label = example.features.feature['label'].int64_list.value
print(label)
if __name__ == '__main__':
imgpath = '/17flowers/'
creat_tf(imgpath)
2.构建VGG16的模型(VGG16.py)
import tensorflow as tf
import numpy as np
# 加载预训练模型
data_dict = np.load('./vgg16.npy', encoding='latin1').item()
# 打印每层信息
def print_layer(t):
print(t.op.name, ' ', t.get_shape().as_list(), '\n')
# 定义卷积层
def conv(x, d_out, name, fineturn=False, xavier=False):
d_in = x.get_shape()[-1].value
with tf.name_scope(name) as scope:
# Fine-tuning
if fineturn:
kernel = tf.constant(data_dict[name][0], name="weights")
bias = tf.constant(data_dict[name][1], name="bias")
print("fineturn")
elif not xavier:
kernel = tf.Variable(tf.truncated_normal