keras.layers.LSTM、Dense等 传递input_shape参数给第一层

@创建于:20210413
@修改于:20210413

1、背景

在keras.layers的Sequential 顺序模型API中,顺序模型是多个网络层的线性堆叠,可以通过将层的列表传递给Sequential 的构造函数。包含的方法和属性有:

  • model.layers 是包含模型网络层的展平列表。
  • model.inputs 是模型输入张量的列表。
  • model.outputs 是模型输出张量的列表。

但是keras.layers的网络层(包括Dense、LSTM、ConvLSTM2D等)没有显式存在Input_shape参数。

  • 为什么要这么设置;
  • 还可以设置哪些其他类似的参数?

2、指定输入数据的尺寸

模型需要知道它所期望的输入的尺寸。出于这个原因,顺序模型中的第一层(只有第一层,因为下面的层可以自动地推断尺寸)需要接收关于其输入尺寸的信息。有几种方法来做到这一点:

  • 传递一个input_shape 参数给第一层。它是一个表示尺寸的元组(一个整数或None 的元组,其中None 表示可能为任何正整数)。在input_shape 中不包含数据的batch 大小。
  • 某些2D 层,例如Dense,支持通过参数input_dim 指定输入尺寸,某些3D 时序层支持input_dim 和input_length 参数。
  • 如果你需要为你的输入指定一个固定的batch 大小(这对stateful RNNs 很有用),你可以传递一个batch_size 参数给一个层。如果你同时将batch_size=32 和input_shape=(6,8) 传递给一个层,那么每一批输入的尺寸就为(32,6,8)。
    因此,下面的代码片段是等价的:
model = Sequential()
model.add(Dense(32, input_shape=(784,)))
model = Sequential()
model.add(Dense(32, input_dim=784))

上面的参考自:
keras-docs-zh-master_text版本 - https://github.com/wanzhenchn/keras-docs-zh

3、核心网络层没有显式的input_shape, input_dim参数,如何传递的?

本人使用的是tensorflow 2.3.0,对应的keras版本是2.3.1(可能是)。
input_shape, input_dim参数是通过**kwargs传递的。
以LSTM为例:

  • LSTM的API接口,继承自recurrent.LSTM;
  • recurrent.LSTM继承自RNN(Layer):里面就有input_shape的来源;
  • RNN(Layer)继承自Layer:里面有可以允许的字典的键。
@keras_export('keras.layers.LSTM', v1=[])
class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM):


@keras_export(v1=['keras.layers.LSTM'])
class LSTM(RNN):


@keras_export('keras.layers.RNN')
class RNN(Layer):
    if 'input_shape' not in kwargs and (
        'input_dim' in kwargs or 'input_length' in kwargs):
      input_shape = (kwargs.pop('input_length', None),
                     kwargs.pop('input_dim', None))
      kwargs['input_shape'] = input_shape
      

@keras_export('keras.layers.Layer')
class Layer(module.Module, version_utils.LayerVersionSelector):
allowed_kwargs = {
        'input_dim',
        'input_shape',
        'batch_input_shape',
        'batch_size',
        'weights',
        'activity_regularizer',
        'autocast',
    }

以上来自官方代码。

4、*args与**kwargs使用

4.1 *args的用法

*args就是就是传递一个可变参数列表给函数实参,这个参数列表的数目未知,甚至长度可以为0。下面这段代码演示了如何使用args。

def test_args(first, *args):
    print('Required argument: ', first)
    print(type(args))
    for v in args:
        print ('Optional argument: ', v)

test_args(1, 2, 3, 4)

第一个参数是必须要传入的参数,所以使用了第一个形参,而后面三个参数则作为可变参数列表传入了实参,并且是作为元组tuple来使用的。代码的运行结果如下:

Required argument:  1
<class 'tuple'>
Optional argument:  2
Optional argument:  3
Optional argument:  4

4.2 **kwargs的用法

**kwargs则是将一个可变的关键字参数的字典传给函数实参,同样参数列表长度可以为0或为其他值。下面这段代码演示了如何使用kwargs。

def test_kwargs(first, *args, **kwargs):
   print('Required argument: ', first)
   print(type(kwargs))
   for v in args:
      print ('Optional argument (args): ', v)
   for k, v in kwargs.items():
      print ('Optional argument %s (kwargs): %s' % (k, v))

test_kwargs(1, 2, 3, 4, k1=5, k2=6)

正如前面所说的,args类型是一个tuple,而kwargs则是一个字典dict,并且args只能位于kwargs的前面。代码的运行结果如下:

参考自:Python中的*args和**kwargs

  • 3
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值