基于vgg16的迁移学习,训练自己的数据集(含预测结果)

本文介绍了vgg16网络的基础知识,包括其结构特点和参数优势。通过详细步骤展示了如何运用vgg16进行迁移学习,涉及数据集准备、预训练权重下载、标注文件生成、TFRecord文件制作以及模型训练。经过20万代训练,验证集上的准确率达到90.75%。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1.vggNet简介

vgg16是2014年由牛津大学提出的一个深度神经网络模型,该模型在2014年的ILSVRC分类比赛中,取得了第二名的成绩,而第一名当属大名鼎鼎的googleNet,vggNet包含5种网络类型,如下图所示:

常见的有vgg16和vgg19。顾名思义vgg16有16层,包含13层卷积池化层和3层全连接层。而vgg19包含16层卷积池化层和3层全连接层。vggNet全部使用1x1,3x3的卷积核,而且vggNet证明了两个3x3的卷积核可以等效为一个5x5的卷积核,下图示

                       


一张5x5的图经两个3x3的卷积核卷积后得到一张1x1的特征图,等效为一个5x5的卷积核。同时在参数量上可以发现,5x5的卷积核的参数量是5x5=25,两个3x3的卷积核是2x3x3=18,参数量是减少了的28%,同时由于与一个5x5的卷积核卷积只需一次非线性激活,而与两个卷积核卷积可以进行两次非线性激活变换,非线性表征加强了,增加了CNN对特征的学习能力。另外1x1卷积核能实现降维,增加非线性。


2.vgg16实现迁移学习

1.数据集准备,我使用8类数据,分别是truck,tiger,flower,kittycat,guitar,houses,plane,person,数据每类训练集500张,验证集300张

2.vgg16预训练权重下载,我把它放在我的百度网盘里了,密码fwi4

3.生成train.txt,val.txt,label.txt

create_labels_files.py

# -*-coding:utf-8-*-

import os
import os.path

def write_txt(content, filename, mode='w'):
    """保存txt数据
    :param content:需要保存的数据,type->list
    :param filename:文件名
    :param mode:读写模式:'w' or 'a'
    :return: void
    """
    with open(filename, mode) as f:
        for line in content:
            str_line = ""
            for col, data in enumerate(line):
                if not col == len(line) - 1:
                    # 以空格作为分隔符
                    str_line = str_line + str(data) + " "
                else:
                    # 每行最后一个数据用换行符“\n”
                    str_line = str_line + str(data) + "\n"
            f.write(str_line)


def get_files_list(dir):
    '''
    实现遍历dir目录下,所有文件(包含子文件夹的文件)
    :param dir:指定文件夹目录
    :return:包含所有文件的列表->list
    '''
    # parent:父目录, filenames:该目录下所有文件夹,filenames:该目录下的文件名
    files_list = []
    for parent, dirnames, filenames in os.walk(dir):
        for filename in filenames:
            print("parent is: " + parent)
            print("filename is: " + filename)
            # print(os.path.join(parent, filename))  # 输出rootdir路径下所有文件(包含子文件)信息
            curr_file = parent.split(os.sep)[-1]
            if curr_file == 'flower':
                labels = 0
            elif curr_file == 'guitar':
                labels = 1
            elif curr_file == 'person':
                labels = 2
            elif curr_file == 'houses':
                labels = 3
            elif curr_file == 'plane':
                labels = 4
            elif curr_file == 'tiger':
                labels = 5
            elif curr_file == 'kittycat':
                labels = 6
            elif curr_file == 'truck':
                labels = 7
            files_list.append([os.path.join(curr_file, filename), labels])
            print(files_list)
    return files_list


if __name__ == '__main__':
    train_dir = 'dataset/train'
    train_txt = 'dataset/train.txt'
    train_data = get_files_list(train_dir)
    write_txt(train_data, train_txt, mode='w')

    val_dir = 'dataset/val'
    val_txt = 'dataset/val.txt'
    val_data = get_files_list(val_dir)
    write_txt(val_data, val_txt, mode='w')

4.制作tf.record文件

create_tf_record.py

# -*-coding: utf-8 -*-

import tensorflow as tf
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import random
from PIL import Image

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
# 生成字符串型的属性
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# 生成实数型的属性
def float_list_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def get_example_nums(tf_records_filenames):
    '''
    统计tf_records图像的个数(example)个数
    :param tf_records_filenames: tf_records文件路径
    :return:
    '''
    nums= 0
    for record in tf.python_io.tf_record_iterator(tf_records_filenames):
        nums += 1
    return nums

def show_image(title,image):
    '''
    显示图片
    :param title: 图像标题
    :param image: 图像的数据
    :return:
    '''
    # plt.figure("show_image")
    # print(image.dtype)
    plt.imshow(image)
    plt.axis('on')    # 关掉坐标轴为 off
    plt.title(title)  # 图像题目
    plt.show()

def load_labels_file(filename,labels_num=1,shuffle=False):
    '''
    载图txt文件,文件中每行为一个图片信息,且以空格隔开:图像路径 标签1 标签2,如:test_image/1.jpg 0 2
    :param filename:
    :param labels_num :labels个数
    :param shuffle :是否打乱顺序
    :return:images type->list
    :return:labels type->list
    '''
    images=[]
    labels=[]
    with open(filename) as f:
        lines_list=f.readlines()
        if shuffle:
            random.shuffle(lines_list)

        for lines in lines_list:
            line=lines.rstrip().split(' ')
            label=[]
            for i in range(labels_nu
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值