Tensorflow学习笔记:CNN篇(9)——Finetuning,复用ImageNet的VGGNet进行图像识别

这篇博客介绍了如何在Tensorflow中复用预训练的VGGNet模型进行图像识别。内容包括从npz文件读取权重、定义VGGNet复用类、权重载入以及模型的使用。通过实例展示了如何对模型进行Finetuning并保存为Tensorflow格式。
摘要由CSDN通过智能技术生成

Tensorflow学习笔记:CNN篇(9)——Finetuning,复用在ImageNet已训练好的VGGNet进行图像识别


前序

— 到目前为止,对于模型的设计和训练,读者可能已经较为熟悉,如果读者已经能够使用设计出的模型进行训练并取得较好的结果,那么,恭喜你,你对Tensorflowd程序的编写已经可以说更上了一层台阶。
— 但是在实际工程或者商业使用中,模型的训练并不都是由程序设计人员独立训练,而是通过复用已有的神经网络模型,导入已训练好的权重数据,从而实现图像的分解的目的。
— VGGNet是最常用的深度学习模型,在各种图片分类和更深一步的语义识别、图像分割上都有好的表现,因此作为最常用的深度学习基础模型被大量采用。


代码示例

Step 1: npz文件的读取

对于复用的VGG模型(在imagenet进行训练的模型),首先第一步是要获得相应的权重文件和对应的分类文件,读者可以在以下地址下载相应的文件。
权重文件:http://www.cs.toronto.edu/~frossard/vgg16/vgg16_weights.npz
分类文件:http://www.cs.toronto.edu/~frossard/vgg16/imagenet_classes.py
这里写图片描述
对于下载下的vgg16_weights.npz文件的说明,需要了解npz文件格式是Numpy包中自带的一种专用的二进制文件存储格式,并且Numpy提供了很多种存取其内容的文件操作函数,可通过自带的load函数对其进行载入,之后将其作为字典赋值给vgg_dict变量,而对其读取可以使用类似字典的方式进行。

import numpy as np
vgg_dict = np.load('./vgg16_weights.npz')
print(vgg_dict.keys())
print(vgg_dict["conv1_1_W"])
Step 2: 复用的VGGNet模型定义

对于复用模型来说,最关键的一步就是复用其中已训练好的权重参数,而通过load方式,可以将其中所包含的数据以字典的形式读出,之后根据参数的不同予以载入。
1. 第一步:定义VGGNet的复用类
首先是对于类的定义,前面已经说过,如果想要复用已经训练完毕的权重参数,则需要在模型中将其作为参数输入。而类中输入参数的方式就是将参数在整个类中共享,故在类的初始化中加入一个全局列表,将所需要共享的参数加载至类中。代码如下:

def __init__(self, imgs):
        self.parameters = [] ##关键语句
        self.imgs = imgs
        self.convlayers()
        self.fc_layers()
        self.probs = tf.nn.softmax(self.fc8)

init中定义的参数与自训练的类中的相同,但是多设置了一个parameter列表,其作用就是将各个层产生的数据以列表元素的方式加载至其中。

def conv(self,name, input_data, out_channel):
        in_channel = input_data.get_shape()[-1]
        with tf.variable_scope(name):
            kernel = tf.get_variable("weights", [3, 3, in_channel, out_channel], dtype=tf.float32)
            biases = tf.get_variable("biases", [out_channel], dtype=tf.float32)
            conv_res = tf.nn.conv2d(input_data, kernel, [1, 1, 1, 1], padding="SAME")
            res = tf.nn.bias_add(conv_res, biases)
            out = tf.nn.relu(res, name=name)
        self.parameters += [kernel, biases] ##关键语句
        return out

同样在卷积层方法定义时,在所有参数定义后,需要一个将参数加载到对应的列表中的方法,即使用self.parameters += [kernel,

  • 7
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值