php keras,Keras实现支持masking的Flatten层代码示例

本篇文章小编给大家分享一下Keras实现支持masking的Flatten层代码示例,代码介绍的很详细,小编觉得挺不错的,现在分享给大家供大家参考,有需要的小伙伴们可以来看看。

Keras原本Flatten的实现

class Flatten(Layer):

def __init__(self, **kwargs):

super(Flatten, self).__init__(**kwargs)

self.input_spec = InputSpec(min_ndim=3)

def compute_output_shape(self, input_shape):

if not all(input_shape[1:]):

raise ValueError('The shape of the input to "Flatten" '

'is not fully defined '

'(got ' + str(input_shape[1:]) + '. '

'Make sure to pass a complete "input_shape" '

'or "batch_input_shape" argument to the first '

'layer in your model.')

return (input_shape[0], np.prod(input_shape[1:]))

def call(self, inputs):

return K.batch_flatten(inputs)

自定义支持masking的实现

事实上,Keras层的mask有时候是需要参与运算的,比如Dense之类的,有时候则只是做某种变换然后传递给后面的层。Flatten属于后者,因为mask总是与input有相同的shape,所以我们要做的就是在compute_mask函数里对mask也做flatten。

from keras import backend as K

from keras.engine.topology import Layer

import tensorflow as tf

import numpy as np

class MyFlatten(Layer):

def __init__(self, **kwargs):

self.supports_masking = True

super(MyFlatten, self).__init__(**kwargs)

def compute_mask(self, inputs, mask=None):

if mask==None:

return mask

return K.batch_flatten(mask)

def call(self, inputs, mask=None):

return K.batch_flatten(inputs)

def compute_output_shape(self, input_shape):

return (input_shape[0], np.prod(input_shape[1:]))

正确性检验

from keras.layers import *

from keras.models import Model

from MyFlatten import MyFlatten

from MySumLayer import MySumLayer

from keras.initializers import ones

data = [[1,0,0,0],

[1,2,0,0],

[1,2,3,0],

[1,2,3,4]]

A = Input(shape=[4]) # None * 4

emb = Embedding(5, 3, mask_zero=True, embeddings_initializer=ones())(A) # None * 4 * 3

fla = MyFlatten()(emb) # None * 12

out = MySumLayer(axis=1)(fla) # None * 1

model = Model(inputs=[A], outputs=[out])

print model.predict(data)

输出:

[ 3. 6. 9. 12.]

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值