【Tensorflow2】tf.keras中Model实例化的方式

最近要用到Tensorflow了,回顾一下。

参考:
使用tf.keras自定义模型建模后model.summary()中Param的计算过程

Model的两种实例化方式

【tf官网】Model实例化方式

在这里插入图片描述

1. 功能性API

def MyModel(input_shape):
    input1 = tf.keras.Input(shape=input_shape,name="input1")
    X = tf.keras.layers.Dense(4,activation=tf.nn.relu,name="dense1")(input1)

    model = tf.keras.Model(inputs=input1,outputs=X,name="my_model") 
    return model

2. 继承tf.keras.Model

class MyModel(tf.keras.Model):
    def __init__(self,input_shape):
        super(MyModel,self).__init__()	# 必须在首行明确
        self.input1 = tf.keras.Input(shape=input_shape,name="input1")

        self.dense1 = tf.keras.layers.Dense(4,activation=tf.nn.relu,name="dense1")

        self.out1 = self.call(self.input1)
        
        # reinitialize
        super(MyModel,self).__init__(
            inputs=self.input1,
            outputs=self.out1,
            name="my_model"
        )
    
    # 前向转播过程
    def call(self,inputs):
        """
        参数:
            input           - 输入,形状必须为 self.input_shape
        """
        x = self.dense1(inputs)
        return x

summary输出

执行以下代码:

if __name__ == '__main__':
    model = MyModel((100,))
    model.summary()

输出如下,

  1. 功能性API
Model: "my_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input1 (InputLayer)          [(None, 100)]             0
_________________________________________________________________
dense1 (Dense)               (None, 4)                 404
=================================================================
Total params: 404
Trainable params: 404
Non-trainable params: 0
_________________________________________________________________
  1. 继承Model
Model: "my_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input1 (InputLayer)          [(None, 100)]             0
_________________________________________________________________
dense1 (Dense)               (None, 4)                 404
=================================================================
Total params: 404
Trainable params: 404
Non-trainable params: 0
_________________________________________________________________

可以看到,summary()输出相同。


model.saveload_model

  1. 功能性API
if __name__ == '__main__':
    model = MyModel((100,))
    model.save("mymodel.h5")							# 保存模型
    model = tf.keras.models.load_model("mymodel.h5")	# 加载模型
    model.summary()
  1. 继承Model

加载模型时,需要明确custom_objects

if __name__ == '__main__':
    model = MyModel((100,))
    model.save("mymodel.h5")	# 保存模型
    # 加载模型,需要明确custom_objects
    model = tf.keras.models.load_model("mymodel.h5",custom_objects={"MyModel":MyModel})	
    model.summary()
  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值