Cannot use TimeDistributed with some Model

项目场景:

视频分类项目用到ShuffleNetV2提取每帧图的特征,对于多帧连续图同时提取特征,需要TimeDistributed层包装ShuffleNetV2模型层,此时本人遇到包装问题,问题如下描述。


问题描述

    feature_layer = TimeDistributed(ShuffleNet_V2(input_shape=[24, 32, CHANNEL_NUM], output_num=FEATURE_SIZE))(org_input)
  File "/home/ml/.local/lib/python3.8/site-packages/keras/engine/base_layer_v1.py", line 765, in __call__
    outputs = call_fn(cast_inputs, *args, **kwargs)
  File "/home/ml/.local/lib/python3.8/site-packages/keras/layers/wrappers.py", line 271, in call
    output_shape = self.compute_output_shape(input_shape)
  File "/home/ml/.local/lib/python3.8/site-packages/keras/layers/wrappers.py", line 190, in compute_output_shape
    child_output_shape = self.layer.compute_output_shape(child_input_shape)
  File "/home/ml/.local/lib/python3.8/site-packages/keras/engine/functional.py", line 470, in compute_output_shape
    layer_output_shapes = layer.compute_output_shape(layer_input_shapes)
  File "/home/ml/.local/lib/python3.8/site-packages/keras/engine/base_layer_v1.py", line 559, in compute_output_shape
    raise NotImplementedError
NotImplementedError


原因分析:

TimeDistributed层包装特征提取层ShuffleNet后,无法计算输出的shape


解决方案:

以下两方法都是定义特征提取模型层后,自主计算特征提取模型层的输出shape

方法一:

Output_shape excludes the batch dimension

layer = hub.KerasLayer(“https://tfhub.dev/some/model/1”, output_shape=(4, 4))

Input_shape includes the batch dimension (can be None)

expected_shape = layer.compute_output_shape(input_shape=(10, 2, 4)).as_list()

方法二:

Load KerasLayer

input = …
layer = hub.KerasLayer(“https://tfhub.dev/some/model/1”,)
net = layer(input)
net = tf.keras.Model(input , net)

Add custom compute_output_shape

net.compute_output_shape = lambda x : (x[0],x[1],512)

Create model

model = tf.keras.Sequential([tf.keras.layers.TimeDistributed(net),…])

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值