Deep Learning Notes --- SENet


SENet Study Notes (Part 1)

Based on my recent studies, this post is a brief walkthrough of an SENet implementation under the TensorFlow deep learning framework. As this is a first pass, parts of my understanding are still rough; discussion and corrections are welcome.

1. SENet program parameters

(1) Weight decay: weight_decay = 0.0005
(2) Momentum: momentum = 0.9
(3) Initial learning rate: init_learning_rate = 0.1
(4) Cardinality (how many splits per block): cardinality = 8
(5) Number of residual blocks: blocks = 3
(6) Depth (number of channels / outputs): depth = 64
(7) Reduction ratio: reduction_ratio = 4 (the paper uses 16)
(8) Batch size: batch_size = 128
(9) Iterations per epoch: iteration = 391
(10) Test iterations: test_iteration = 10
(11) Total training epochs: total_epochs = 100
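Gathered together, the settings above amount to the following module-level constants (a sketch; the names match the list above and the training code below):

weight_decay = 0.0005
momentum = 0.9
init_learning_rate = 0.1
cardinality = 8            # number of splits per ResNeXt block
blocks = 3                 # residual blocks per residual layer
depth = 64                 # channel depth of each split
reduction_ratio = 4        # SE bottleneck reduction (the paper uses 16)
batch_size = 128
iteration = 391            # ceil(50000 / 128) batches per epoch on CIFAR-10
test_iteration = 10
total_epochs = 100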

2. Convolutional layer definitions:

(1) Parameters: input, number of filters, kernel size, stride, padding type, layer name.
(2) There are two padding types:
padding='SAME' or 'VALID' (the default is 'VALID').
Difference: 'SAME' zero-pads when the remaining region is smaller than the kernel, while 'VALID' simply discards the leftover part. For example, with a 4×4 input, a 3×3 kernel, and stride 2, 'SAME' pads the border to produce a 2×2 output, whereas 'VALID' yields only 1×1.
(3) tf.layers.conv2d(), the 2-D convolution function.
Passing use_bias=False or bias_initializer=None disables the bias term.
[Figure: parameter table for tf.layers.conv2d]
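A minimal sketch of a convolution helper with these parameters, assuming tf.layers.conv2d with the bias disabled (the name conv_layer and its exact signature are assumptions):

def conv_layer(input, filter, kernel, stride=1, padding='SAME', layer_name="conv"):
    # 2-D convolution with bias disabled (batch normalization follows it)
    with tf.name_scope(layer_name):
        return tf.layers.conv2d(inputs=input, filters=filter, kernel_size=kernel,
                                strides=stride, padding=padding, use_bias=False)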
(4) Global average pooling and average pooling:

def Global_Average_Pooling(x):
    # global_avg_pool is presumably tflearn's (from tflearn.layers.conv import global_avg_pool)
    return global_avg_pool(x, name='Global_avg_pooling')

def Average_pooling(x, pool_size=[2, 2], stride=2, padding='SAME'):
    return tf.layers.average_pooling2d(inputs=x, pool_size=pool_size, strides=stride, padding=padding)
(5) Batch normalization via batch_normalization(), which speeds up training.
Detailed usage:
https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization
The function tf.cond(pred, lambda: ..., lambda: ...) behaves like an if...else statement, choosing which branch of the graph to execute based on the boolean tensor pred.
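A minimal sketch of how tf.cond can switch batch normalization between training and inference mode (the helper name and the reuse pattern are assumptions):

def Batch_Normalization(x, training, scope):
    # training mode updates the moving mean/variance; inference mode reuses them
    return tf.cond(training,
                   lambda: tf.layers.batch_normalization(x, training=True, name=scope, reuse=None),
                   lambda: tf.layers.batch_normalization(x, training=False, name=scope, reuse=True))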
(6) Activation functions: Relu() and Sigmoid().
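These are presumably thin wrappers over the corresponding TensorFlow ops:

def Relu(x):
    return tf.nn.relu(x)

def Sigmoid(x):
    return tf.nn.sigmoid(x)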
(7) Filter concatenation (depth-wise stacking of maps with the same spatial size).
Convolution outputs with identical spatial dimensions are joined along the channel axis.
Example: given two 3×3×4 inputs, one 3×3×2 input, and three 3×3×1 inputs, there are 2×4 + 1×2 + 3×1 = 13 channels in total, so the filter concatenation produces a 3×3×13 output.
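In TensorFlow this is a single concat along the channel axis (the helper name is an assumption):

def Concatenation(layers):
    # NHWC layout: axis 3 is the channel axis
    return tf.concat(layers, axis=3)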

(8) Fully connected layer.
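A minimal sketch of such a helper, assuming tf.layers.dense with the bias disabled, in the same style as the other wrappers:

def Fully_connected(x, units=class_num, layer_name='fully_connected'):
    # dense layer used for the SE excitation FCs and the final classifier
    with tf.name_scope(layer_name):
        return tf.layers.dense(inputs=x, units=units, use_bias=False)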
(9) Loss function / model evaluation with Evaluate(), which returns test_acc, test_loss, and a summary:

def Evaluate(sess):
    test_acc = 0.0
    test_loss = 0.0
    test_pre_index = 0
    add = 1000  # 10 iterations x 1000 images covers the 10000-image test set

    for it in range(test_iteration):
        test_batch_x = test_x[test_pre_index: test_pre_index + add]
        test_batch_y = test_y[test_pre_index: test_pre_index + add]
        test_pre_index = test_pre_index + add

        test_feed_dict = {
            x: test_batch_x,
            label: test_batch_y,
            learning_rate: epoch_learning_rate,
            training_flag: False  # inference mode: batch norm uses moving statistics
        }

        loss_, acc_ = sess.run([cost, accuracy], feed_dict=test_feed_dict)

        test_loss += loss_
        test_acc += acc_

    test_loss /= test_iteration  # average loss
    test_acc /= test_iteration  # average accuracy

    summary = tf.Summary(value=[tf.Summary.Value(tag='test_loss', simple_value=test_loss),
                                tf.Summary.Value(tag='test_accuracy', simple_value=test_acc)])

    return test_acc, test_loss, summary

A summary op produces a single summary protocol buffer for some tensor, in a format that TensorBoard can read.
A whole graph often has many values worth monitoring, i.e. many summary ops, and running each one individually to launch it is tedious, so TensorFlow provides tf.summary.merge_all() for this.
The generated result is an events file; the visualization can then be viewed by launching TensorBoard:

tensorboard --logdir=./logs
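A generic usage sketch of merged summaries (note this is not what the code below does, which constructs tf.Summary protos by hand):

tf.summary.scalar('train_loss', cost)  # register a scalar summary op
merged = tf.summary.merge_all()        # a single op that evaluates all of them
summary, _ = sess.run([merged, train], feed_dict=train_feed_dict)
summary_writer.add_summary(summary, global_step=step)  # append to the events file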
(10) Building the SE_ResNeXt network structure:

- def __init__(self, x, training): initialization

- def first_layer(self, x, scope): first convolution layer (conv, batch norm, ReLU)

- def transform_layer(self, x, stride, scope): per-split transform convolutions (conv, batch norm, ReLU)

- def transition_layer(self, x, out_dim, scope): transition convolution (conv, batch norm)

- def split_layer(self, input_x, stride, layer_name): split and merge the cardinality branches (see Filter concatenation above)

- def squeeze_excitation_layer(self, input_x, out_dim, ratio, layer_name): squeeze-and-excitation layer (sketched after this list)

- def residual_layer(self, input_x, out_dim, layer_num, res_block=blocks): residual layer (composed of residual blocks)

The SE-Net network is built by stacking multiple residual layers.
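A minimal sketch of the squeeze-and-excitation layer, built from the helpers above and following the paper's squeeze → excitation → scale pipeline (the body is an assumption, not quoted from the post):

def squeeze_excitation_layer(self, input_x, out_dim, ratio, layer_name):
    with tf.name_scope(layer_name):
        # Squeeze: global average pooling turns each channel map into one scalar
        squeeze = Global_Average_Pooling(input_x)
        # Excitation: bottleneck FC (out_dim // ratio) -> ReLU -> FC (out_dim) -> Sigmoid
        excitation = Fully_connected(squeeze, units=out_dim // ratio, layer_name=layer_name + '_fc1')
        excitation = Relu(excitation)
        excitation = Fully_connected(excitation, units=out_dim, layer_name=layer_name + '_fc2')
        excitation = Sigmoid(excitation)
        # Scale: reshape to [N, 1, 1, C] and reweight the input channel-wise
        excitation = tf.reshape(excitation, [-1, 1, 1, out_dim])
        return input_x * excitation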

(11) Training data (inputs x, labels y) and test data.
(12) Variable definitions (the placeholders appear at the top of the code below).
(13) Training the model: compute the loss cost, build the optimizer, minimize cost + l2_loss * weight_decay, and track accuracy.
Optimizers: https://blog.csdn.net/xierhacker/article/details/53174558

# image_size = 32, img_channels = 3, class_num = 10 in cifar10
x = tf.placeholder(tf.float32, shape=[None, image_size, image_size, img_channels])
label = tf.placeholder(tf.float32, shape=[None, class_num])

training_flag = tf.placeholder(tf.bool)


learning_rate = tf.placeholder(tf.float32, name='learning_rate')

logits = SE_ResNeXt(x, training=training_flag).model
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=label, logits=logits))

l2_loss = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=momentum, use_nesterov=True)
train = optimizer.minimize(cost + l2_loss * weight_decay)

correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

saver = tf.train.Saver(tf.global_variables())

with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('./model')
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())

    summary_writer = tf.summary.FileWriter('./logs', sess.graph)

    epoch_learning_rate = init_learning_rate
    for epoch in range(1, total_epochs + 1):
        if epoch % 30 == 0:  # decay the learning rate 10x every 30 epochs
            epoch_learning_rate = epoch_learning_rate / 10

        pre_index = 0
        train_acc = 0.0
        train_loss = 0.0

        for step in range(1, iteration + 1):
            if pre_index + batch_size < 50000:  # CIFAR-10 has 50000 training images
                batch_x = train_x[pre_index: pre_index + batch_size]
                batch_y = train_y[pre_index: pre_index + batch_size]
            else:
                batch_x = train_x[pre_index:]
                batch_y = train_y[pre_index:]

            batch_x = data_augmentation(batch_x)  # augmentation helper defined elsewhere

            train_feed_dict = {
                x: batch_x,
                label: batch_y,
                learning_rate: epoch_learning_rate,
                training_flag: True
            }

            _, batch_loss = sess.run([train, cost], feed_dict=train_feed_dict)
            batch_acc = accuracy.eval(feed_dict=train_feed_dict)

            train_loss += batch_loss
            train_acc += batch_acc
            pre_index += batch_size


        train_loss /= iteration # average loss
        train_acc /= iteration # average accuracy

        train_summary = tf.Summary(value=[tf.Summary.Value(tag='train_loss', simple_value=train_loss),
                                          tf.Summary.Value(tag='train_accuracy', simple_value=train_acc)])

        test_acc, test_loss, test_summary = Evaluate(sess)

        summary_writer.add_summary(summary=train_summary, global_step=epoch)
        summary_writer.add_summary(summary=test_summary, global_step=epoch)
        summary_writer.flush()

        line = "epoch: %d/%d, train_loss: %.4f, train_acc: %.4f, test_loss: %.4f, test_acc: %.4f \n" % (
            epoch, total_epochs, train_loss, train_acc, test_loss, test_acc)
        print(line)

        with open('logs.txt', 'a') as f:
            f.write(line)

        saver.save(sess=sess, save_path='./model/ResNeXt.ckpt')