NVIDIA Deep Learning TensorFlow Source Code - developed by 立哥

# Copyright 2008-2021 Jacky Zong. All rights reserved.
# Your Highness,
# to die in battle for you is my supreme honor! - 《离思五首》

import tensorflow as tf

from model import layers

__all__ = ['conv2d_block']


def conv2d_block(
    inputs,
    n_channels,
    kernel_size=(3, 3),
    strides=(2, 2),
    mode='SAME',
    use_batch_norm=True,
    activation='relu',
    is_training=True,
    data_format='NHWC',
    conv2d_hparams=None,
    batch_norm_hparams=None,
    name='conv2d',
    cardinality=1,
):
    """Builds a convolution block: conv2d -> optional batch norm -> activation.

    When `cardinality` > 1, a grouped (ResNeXt-style) convolution with a
    hard-coded 3x3 kernel is used instead of `layers.conv2d`.
    """

    if not isinstance(conv2d_hparams, tf.contrib.training.HParams):
        raise ValueError("The parameter `conv2d_hparams` is not of type `HParams`")

    if use_batch_norm and not isinstance(batch_norm_hparams, tf.contrib.training.HParams):
        raise ValueError("The parameter `batch_norm_hparams` is not of type `HParams`")

    with tf.variable_scope(name):
        if cardinality == 1:
            # Standard convolution; the bias is dropped when batch norm follows.
            net = layers.conv2d(
                inputs,
                n_channels=n_channels,
                kernel_size=kernel_size,
                strides=strides,
                padding=mode,
                data_format=data_format,
                use_bias=not use_batch_norm,
                trainable=is_training,
                kernel_initializer=conv2d_hparams.kernel_initializer,
                bias_initializer=conv2d_hparams.bias_initializer)
        else:
            # Grouped (ResNeXt-style) convolution: the filter's input-channel
            # dimension is divided by `cardinality`. Note the kernel size is
            # hard-coded to 3x3 here and ignores `kernel_size`.
            group_filter = tf.get_variable(
                name=name + 'group_filter',
                shape=[3, 3, n_channels // cardinality, n_channels],
                trainable=is_training,
                dtype=tf.float32)
            net = tf.nn.conv2d(
                inputs,
                group_filter,
                strides=strides,
                padding='SAME',
                data_format=data_format)
        if use_batch_norm:
            # Batch normalization, configured from `batch_norm_hparams`.
            net = layers.batch_norm(
                net,
                decay=batch_norm_hparams.decay,
                epsilon=batch_norm_hparams.epsilon,
                scale=batch_norm_hparams.scale,
                center=batch_norm_hparams.center,
                is_training=is_training,
                data_format=data_format,
                param_initializers=batch_norm_hparams.param_initializers
            )

        # Apply the requested activation; 'linear' or None means identity.
        if activation == 'relu':
            net = layers.relu(net, name='relu')

        elif activation == 'tanh':
            net = layers.tanh(net, name='tanh')

        elif activation != 'linear' and activation is not None:
            raise KeyError('Invalid activation type: `%s`' % activation)

        return net
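
Below is a minimal usage sketch (not part of the original listing). It assumes a TensorFlow 1.x runtime where `tf.contrib.training.HParams` is available and where `model.layers` provides the `conv2d`, `batch_norm`, and `relu` wrappers imported above; the hyper-parameter field names simply mirror the attributes that `conv2d_block` reads, and the concrete initializer and batch-norm values are illustrative.

# Usage sketch: build one conv -> batch-norm -> ReLU block on an NHWC input.
conv_hparams = tf.contrib.training.HParams(
    kernel_initializer=tf.variance_scaling_initializer(),
    bias_initializer=tf.zeros_initializer(),
)

bn_hparams = tf.contrib.training.HParams(
    decay=0.9,
    epsilon=1e-5,
    scale=True,
    center=True,
    param_initializers={
        'beta': tf.zeros_initializer(),
        'gamma': tf.ones_initializer(),
        'moving_mean': tf.zeros_initializer(),
        'moving_variance': tf.ones_initializer(),
    },
)

# Placeholder batch of 224x224 RGB images (TF 1.x graph mode).
images = tf.placeholder(tf.float32, shape=[None, 224, 224, 3])

net = conv2d_block(
    images,
    n_channels=64,
    kernel_size=(7, 7),
    strides=(2, 2),
    mode='SAME',
    use_batch_norm=True,
    activation='relu',
    is_training=True,
    data_format='NHWC',
    conv2d_hparams=conv_hparams,
    batch_norm_hparams=bn_hparams,
    name='conv1',
)
# With 'SAME' padding and stride 2, `net` has shape (?, 112, 112, 64).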
