An introduction to the tensorflow-slim module

Using the slim module, we implement VGG16 quickly and concisely, then fine-tune it in two ways: fine-tuning all layers, or fine-tuning only the layers before fc8.

Note: slim.learning.train() and slim.learning.create_train_op() do not feed data the usual way (there are no placeholders); they expect a TFRecords input pipeline built with slim utilities and fed in directly. That pipeline is hard to follow, so this post only borrows individual slim functions and does not use slim's training loop.
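For reference, slim's own loop looks roughly like this (a sketch; it assumes a total_loss built from tensors coming out of a TFRecords input pipeline, not from placeholders):

optimizer = tf.train.GradientDescentOptimizer(1e-4)
train_op = slim.learning.create_train_op(total_loss, optimizer)
slim.learning.train(train_op, logdir='./log', number_of_steps=1000)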

slim.get_model_variables('vgg_16') automatically matches every variable whose name starts with the given prefix.
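For example, once a VGG-16 graph like the ones below has been built, this prints names such as vgg_16/conv1/conv1_1/weights:0:

for v in slim.get_model_variables('vgg_16')[:4]:
    print(v.name)  # every name starts with the 'vgg_16' prefix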

1) Implement VGG16 with all weights initialized randomly:

import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.contrib.slim import nets

def vgg16(inputs,is_training):

    network=nets.vgg

    # vgg_16 returns (logits, end_points); keep only the logits here.
    net,_=network.vgg_16(inputs,1000,is_training=is_training)

    return net

The same network written out layer by layer with slim primitives (this definition shadows the wrapper above):

def vgg16(inputs,is_training):
  with slim.arg_scope([slim.conv2d, slim.fully_connected],
                      activation_fn=tf.nn.relu,
                      weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                      weights_regularizer=slim.l2_regularizer(0.0005)):
    net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
    net = slim.max_pool2d(net, [2, 2], scope='pool1')
    net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
    net = slim.max_pool2d(net, [2, 2], scope='pool2')
    net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
    net = slim.max_pool2d(net, [2, 2], scope='pool3')
    net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
    net = slim.max_pool2d(net, [2, 2], scope='pool4')
    net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
    net = slim.max_pool2d(net, [2, 2], scope='pool5')
    # Flatten [N, 7, 7, 512] to [N, 25088] so the FC layers produce [N, 1000] logits.
    net = slim.flatten(net, scope='flatten5')
    net = slim.fully_connected(net, 4096, scope='fc6')
    net = slim.dropout(net, 0.5, is_training=is_training, scope='dropout6')
    net = slim.fully_connected(net, 4096, scope='fc7')
    net = slim.dropout(net, 0.5, is_training=is_training, scope='dropout7')

    net = slim.fully_connected(net, 1000, activation_fn=None, scope='fc8')

  return net
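A quick usage sketch for the definition above (the input placeholder is illustrative):

inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
logits = vgg16(inputs, is_training=True)  # shape [None, 1000]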

2) Fine-tune based on VGG16:

''' vgg16 '''
import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.contrib.slim import nets
import numpy as np
import random

def one_hot(labels,class_num):
    # Convert an integer label vector [N] into a one-hot matrix [N, class_num].
    N=labels.shape[0]
    one_hot=np.zeros([N,class_num],dtype=np.int32)
    for i in range(N):
        for j in range(class_num):
            one_hot[i,j]=np.int32(labels[i]==j)
    return one_hot
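# Vectorized alternative (a one-liner doing the same conversion):
# one_hot = np.eye(class_num, dtype=np.int32)[labels]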
'''
Higher-level building blocks: slim.nets (call network definitions that are already written)
'''
sess=tf.Session()
fine_tune_path=r'**/vgg_16.ckpt'
#reader=tf.train.NewCheckpointReader(fine_tune_path)
#key=reader.get_variable_to_shape_map()# inspect the checkpoint's variable names and shapes

class_num=21

# To fine-tune the whole classification network (including the FC layers),
# the input images must be 224*224;
# if the FC layers are not fine-tuned, the input can be any size.

image=tf.placeholder(tf.float32,shape=[None,224,224,3],name='image')  
label=tf.placeholder(tf.int32,shape=[None,class_num],name='label')
is_training=tf.placeholder(tf.bool,name='is_training')

network=nets.vgg
net,end_points=network.vgg_16(image,class_num,is_training=is_training)# returns two values: logits and end_points
#init_fn=slim.assign_from_checkpoint_fn(fine_tune_path,slim.get_model_variables('vgg_16'))# this would restore ALL weights (full fine-tune)
#print(net.get_shape().as_list())
#print(end_points.keys())
softmax=tf.nn.softmax(net)
pred=tf.argmax(softmax,axis=-1)
tf.add_to_collection('pred',pred)

init_fn=None
# Restore only the weights before fc8. These three lines can sit anywhere before
# sess.run(init_op): they only build init_fn, which later loads the weights into the session.
exclude=['vgg_16/fc8']# slim's variable getters automatically match the weights and biases under fc8
variables_to_restore=slim.get_variables_to_restore(exclude=exclude)
init_fn=slim.assign_from_checkpoint_fn(fine_tune_path,variables_to_restore)

''' slim.losses.softmax_cross_entropy and tf.nn.softmax_cross_entropy_with_logits are
    not the same: the first already returns the mean over the batch, while the second
    returns per-example losses that still need an explicit mean. '''
loss=slim.losses.softmax_cross_entropy(logits=net,onehot_labels=label)# label: [batch_size, num_class]
''' slim.losses is deprecated and will be replaced by tf.losses in the future. '''
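# Equivalent with tf.nn (sketch): note the explicit mean and the float labels.
# loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
#     labels=tf.cast(label, tf.float32), logits=net))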
train_op=tf.train.GradientDescentOptimizer(1e-4).minimize(loss)
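# Optional sketch (assumes the standard slim scope names): to freeze all restored
# layers and train ONLY the new fc8 classifier, pass var_list to the optimizer.
# fc8_vars = slim.get_variables(scope='vgg_16/fc8')
# train_op = tf.train.GradientDescentOptimizer(1e-4).minimize(loss, var_list=fc8_vars)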

init_op=tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
sess.run(init_op)
if init_fn is not None:
    # Run AFTER init_op so the restored weights overwrite the random initialization.
    init_fn(sess)
    print('successfully loaded fine-tune weights')
saver=tf.train.Saver(max_to_keep=1)

batch_size=4
for step in range(1000):
    # Dummy data for demonstration: random images and distinct random labels.
    imgs=np.random.random([batch_size,224,224,3])
    labs=one_hot(np.array(random.sample(range(class_num),batch_size)),class_num)
    # Run the update and fetch the loss in a single forward/backward pass.
    _,Loss=sess.run([train_op,loss],feed_dict={image:imgs,label:labs,is_training:True})
    print(Loss)
    if step%100==0:
        saver.save(sess,r'./model/model.ckpt',global_step=step)
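
Finally, a small inference sketch (assumes the graph above is still in memory and ./model is the checkpoint directory used during training):

ckpt = tf.train.latest_checkpoint('./model')
saver.restore(sess, ckpt)
print(sess.run(pred, feed_dict={image: imgs, is_training: False}))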

