用VGG 16训练数据集并保存模型测试效果+Tensorboard可视化收敛曲线

**

<一>准备工作!!!

**
需要下载的东西如下:
数据集:17flowers数据集
预训练权重模型:vgg16.npy 提取码:66im
我的做法是从17flowers数据集里面的17个分类的文件夹中,依次分别cut出5张图片作为测试集。

<二>训练具体流程及细节说明

1.首先我们要将17flowers数据集转换成TFRecord的形式(tfrecord.py

这是因为TFRecord格式是为Tensorflow打造的一种非常高效的数据读取方式,在了解TFRcord 的过程中,楼主看到了超多优秀的资料!比如:你可能无法回避的 TFRecord 文件格式详细讲解

这里给出

// An highlighted block
import os
import tensorflow as tf
from PIL import Image

def creat_tf(imgpath):
    cwd = os.getcwd()
    classes = os.listdir(cwd + imgpath)

    # 定义tfrecords文件存放
    writer = tf.python_io.TFRecordWriter("train.tfrecords")
    for index, name in enumerate(classes):
        class_path = cwd + imgpath + name + "/"
        print(class_path)
        if os.path.isdir(class_path):
            for img_name in os.listdir(class_path):
                img_path = class_path + img_name
                img = Image.open(img_path)
                img = img.resize((224, 224))
                img_raw = img.tobytes()
                example = tf.train.Example(features=tf.train.Features(feature={
   
                    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(name)])),
                    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
                }))
                writer.write(example.SerializeToString())
                print(img_name)
    writer.close()


def read_example():
    # 读取
    for serialized_example in tf.python_io.tf_record_iterator("train.tfrecords"):
        example = tf.train.Example()
        example.ParseFromString(serialized_example)

        # image = example.features.feature['img_raw'].bytes_list.value
        label = example.features.feature['label'].int64_list.value
        print(label)

if __name__ == '__main__':
    imgpath = '/17flowers/'
    creat_tf(imgpath)

2.构建VGG16的模型(VGG16.py


import tensorflow as tf
import numpy as np

# 加载预训练模型
data_dict = np.load('./vgg16.npy', encoding='latin1').item()


# 打印每层信息
def print_layer(t):
    print(t.op.name, ' ', t.get_shape().as_list(), '\n')


# 定义卷积层
def conv(x, d_out, name, fineturn=False, xavier=False):
    d_in = x.get_shape()[-1].value
    with tf.name_scope(name) as scope:
        # Fine-tuning
        if fineturn:
            kernel = tf.constant(data_dict[name][0], name="weights")
            bias = tf.constant(data_dict[name][1], name="bias")
            print("fineturn")
        elif not xavier:
            kernel = tf.Variable(tf.truncated_normal
  • 2
    点赞
  • 41
    收藏
    觉得还不错? 一键收藏
  • 11
    评论
若你要使用 VGG16 模型进行训练数据集,步骤如下: 1. 准备数据集,将数据集分成训练集、验证集和测试集。 2. 对数据进行预处理,如图像尺寸调整、图像增强等。 3. 加载 VGG16 模型,并进行微调。可以选择冻结一部分层的权重,只训练最后几层的权重。 4. 编译模型,选择合适的损失函数、优化器等。 5. 使用模型训练集进行训练,并在验证集上进行验证,调整模型超参数。 6. 对测试集进行测试,评估模型性能。 下面是一个简单的示例代码,展示了如何使用 VGG16 模型对 CIFAR-10 数据集进行训练测试: ```python import tensorflow as tf from tensorflow.keras.datasets import cifar10 from tensorflow.keras.applications.vgg16 import VGG16 from tensorflow.keras.layers import Dense, Flatten from tensorflow.keras.models import Model # 加载 CIFAR-10 数据集 (x_train, y_train), (x_test, y_test) = cifar10.load_data() # 数据预处理 x_train = tf.keras.applications.vgg16.preprocess_input(x_train) x_test = tf.keras.applications.vgg16.preprocess_input(x_test) # 加载 VGG16 模型,不包括顶层 base_model = VGG16(weights='imagenet', include_top=False, input_shape=(32, 32, 3)) # 冻结模型前面的层 for layer in base_model.layers: layer.trainable = False # 添加新的全连接层 x = Flatten()(base_model.output) x = Dense(256, activation='relu')(x) predictions = Dense(10, activation='softmax')(x) # 构建新的模型 model = Model(inputs=base_model.input, outputs=predictions) # 编译模型 model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) # 训练模型 model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=10, batch_size=32) # 在测试集上评估模型性能 model.evaluate(x_test, y_test) ``` 需要注意的是,这只是一个示例代码,具体情况需要根据数据集的特点进行调整。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值