ResNet —— 残差神经网络搭建 (building a residual neural network)

# -*- coding: utf-8 -*-
import tensorflow as  tf
from collections import namedtuple
from math import sqrt

def print_activations(t):
    """Debug helper: print a tensor's op name and its static shape."""
    shape = t.get_shape().as_list()
    print(t.op.name, shape)

def conv2d(x, n_fliters, k_h=5, k_w=5,
           stride_h=2, stride_w=2,
           stddev=0.02, activation=lambda x: x,
           bias=True, padding='SAME', name="Conv2D"):
    """2-D convolution layer with optional bias and activation.

    Creates a `weight` variable (and a `bias` variable when `bias` is
    True) inside the variable scope `name`, records histogram summaries
    for both, prints the conv output's name/shape, and returns
    ``activation(conv)``.  Default stride of 2 downsamples the input.
    """
    with tf.variable_scope(name):
        init = tf.truncated_normal_initializer(stddev=stddev)
        kernel = tf.get_variable(
            'weight',
            [k_h, k_w, x.get_shape()[-1], n_fliters],
            initializer=init)
        tf.summary.histogram(name + 'weight', kernel)
        out = tf.nn.conv2d(x, kernel,
                           strides=[1, stride_h, stride_w, 1],
                           padding=padding)
        if bias:
            b = tf.get_variable('bias', [n_fliters], initializer=init)
            tf.summary.histogram(name + 'bias', b)
            out = out + b
        print_activations(out)
        return activation(out)

def linear(x, n_units, scope=None, stddev=0.02,
           activation=tf.identity):
    """Fully-connected layer: returns ``activation(x @ weight + bias)``.

    Bug fixes versus the transcribed original:
    - ``tf.summary.histogram`` was called with a single positional
      argument (the pre-activation tensor); the API requires
      ``(name, values)``, so this raised at graph-build time and the
      bias histogram was lost.
    - the function never returned anything (callers such as ResNet
      assign its result), and the ``activation`` parameter was unused.
    """
    shape = x.get_shape().as_list()
    with tf.variable_scope(scope or "linear"):
        weight = tf.get_variable("weight", [shape[1], n_units], tf.float32,
                                 tf.random_normal_initializer(stddev=stddev))
        tf.summary.histogram('weight', weight)
        bias = tf.get_variable('bias', [n_units], tf.float32,
                               tf.random_normal_initializer(stddev=stddev))
        tf.summary.histogram('bias', bias)
        return activation(tf.matmul(x, weight) + bias)

def ResNet(x, n_outputs, activation=tf.nn.relu):
    """Build a bottleneck residual network and return its logits.

    Args:
        x: input tensor; either NHWC, or flat ``[batch, n*n]`` (then it
           is reshaped to a square single-channel image).
        n_outputs: number of output units of the final linear layer.
        activation: nonlinearity applied inside the conv layers.

    Returns:
        The ``[batch, n_outputs]`` output of the final linear layer.

    Raises:
        ValueError: if a flat input's feature count is not a perfect square.

    Fixes versus the transcribed original:
    - ``padding='VAlID'`` typo on the conv2 layer (TF rejects it);
    - the 3x3 bottleneck conv used VALID padding, shrinking the spatial
      dims and breaking the residual add ``conv + net`` — it must be SAME;
    - the final flatten passed ``strides``/``padding`` kwargs to
      ``tf.reshape`` (a TypeError) and dropped the channel dimension;
    - scope-name typo ``'blcok_%d'``.
    """
    # Block spec is pure data: (num_repeats, output channels, bottleneck channels).
    LayerBlock = namedtuple('LayerBlock',
                            ['num_repeats', 'num_fiters', 'bottleneck_size'])
    blocks = [LayerBlock(3, 128, 32),
              LayerBlock(3, 256, 64),
              LayerBlock(3, 512, 128),
              LayerBlock(3, 1024, 256),
              LayerBlock(3, 2048, 512),
              LayerBlock(3, 4096, 1024)]

    input_shape = x.get_shape().as_list()
    if len(input_shape) == 2:
        # Flat input: reshape to a square single-channel image.
        ndim = int(sqrt(input_shape[1]))
        if ndim * ndim != input_shape[1]:
            raise ValueError('input_shape should be square')
        x = tf.reshape(x, [-1, ndim, ndim, 1])
        tf.summary.image('input', x, 10)

    # First conv expands to 64 channels and downsamples (default stride 2).
    net = conv2d(x, 64, k_h=7, k_w=7, name='conv1', activation=activation)

    net = tf.nn.max_pool(net, [1, 2, 2, 1], strides=[1, 2, 2, 1],
                         padding='SAME')
    print_activations(net)

    # 1x1 conv up to the first block's channel count (padding typo fixed).
    net = conv2d(net, blocks[0].num_fiters, k_h=1, k_w=1,
                 stride_h=1, stride_w=1, padding='VALID', name='conv2')

    for block_i, block in enumerate(blocks):  # loop over residual blocks
        for repeat_i in range(block.num_repeats):
            name = 'block_%d/repeat_%d' % (block_i, repeat_i)
            # 1x1 reduce -> 3x3 bottleneck -> 1x1 expand.
            conv = conv2d(net, block.bottleneck_size, k_h=1, k_w=1,
                          padding='VALID', stride_h=1, stride_w=1,
                          activation=activation, name=name + '/conv_in')
            # SAME padding keeps spatial dims so the residual add works.
            conv = conv2d(conv, block.bottleneck_size, k_h=3, k_w=3,
                          padding='SAME', stride_h=1, stride_w=1,
                          activation=activation,
                          name=name + '/conv_bottleneck')
            conv = conv2d(conv, block.num_fiters, k_h=1, k_w=1,
                          padding='VALID', stride_h=1, stride_w=1,
                          activation=activation, name=name + '/conv_out')
            net = conv + net  # residual (identity shortcut) connection

        try:
            # Widen channels to match the next block; last block has none.
            next_block = blocks[block_i + 1]
            net = conv2d(net, next_block.num_fiters, k_h=3, k_w=3,
                         padding='SAME', stride_h=1, stride_w=1,
                         name='block_%d/conv_upscale' % block_i)
        except IndexError:
            pass

    # Global average pool over the remaining spatial dimensions.
    spatial = net.get_shape().as_list()
    net = tf.nn.avg_pool(net,
                         ksize=[1, spatial[1], spatial[2], 1],
                         strides=[1, 1, 1, 1], padding='VALID')
    print_activations(net)

    # Flatten to [batch, features] for the final linear layer.
    dims = net.get_shape().as_list()
    net = tf.reshape(net, [-1, dims[1] * dims[2] * dims[3]])
    print_activations(net)

    net = linear(net, n_outputs)

    return net
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值