针对tensorflow2.0 在训练模型中添加了自定义类后不能保存为.h5的例子(config的问题)

针对tensorflow2.0 在训练模型中添加了自定义类后不能保存为.h5的例子, 修改方法主要采用引入一个def get_config(self)的子类,
先继承sequential的config,然后采用字典进行保存对应的self的输入变量即可:
本帖子很明晰那输入时in_fan和out_fan,在def get_config(self)的子类中,需要进行updata这两个数据,然后保存为字典格式
#在训练的模型中必须定义config

class Skip_con(tf.keras.layers.Layer):
    def __init__(self,  in_fan, out_fan,**kwargs):
        super(Skip_con,self).__init__()
        self.in_fan=in_fan
        self.out_fan=out_fan
        self.liner=tf.keras.layers.Dense(out_fan,use_bias=False,activation='relu')
        self.transform=tf.keras.layers.Dense(in_fan,use_bias=False)
        self.bn1=tf.keras.layers.BatchNormalization()
        self.bn2=tf.keras.layers.BatchNormalization()
    def call(self,x,**kwargs):
        _x=self.liner(x)
        if self.in_fan==self.out_fan:
            return self.bn1(x+_x)
        elif self.in_fan!=self.out_fan:
            x_=self.transform(_x)
            return self.bn2(x+x_)
    # def get_config(self):
    #     config = {'in_fan': self.in_fan,
    #                    'out_fan': self.out_fan}
    #     base_config=super(Skip_con, self).get_config()
 
    #     return dict(list(base_config.items()) + list(config.items()))
    def get_config(self):
        config = super(Skip_con, self).get_config()
        config.update({'in_fan':self.in_fan,
                       'out_fan':self.out_fan})
        return config

#在调用的过程中,建立新的程序,必须先再次定义类,然后采用custom_objects声明调用的类



```python
class Skip_con(tf.keras.layers.Layer):
    def __init__(self,  in_fan, out_fan,**kwargs):
        super(Skip_con,self).__init__()
        self.in_fan=in_fan
        self.out_fan=out_fan
        self.liner=tf.keras.layers.Dense(out_fan,use_bias=False,activation='relu')
        self.transform=tf.keras.layers.Dense(in_fan,use_bias=False)
        self.bn1=tf.keras.layers.BatchNormalization()
        self.bn2=tf.keras.layers.BatchNormalization()
    def call(self,x,**kwargs):
        _x=self.liner(x)
        if self.in_fan==self.out_fan:
            return self.bn1(x+_x)
        elif self.in_fan!=self.out_fan:
            x_=self.transform(_x)
            return self.bn2(x+x_)
    def get_config(self):
        config = super(Skip_con, self).get_config()
        config.update({'in_fan':self.in_fan,
                       'out_fan':self.out_fan})
        return config
    
model=tf.keras.models.load_model('weinian.h5',custom_objects={'Skip_con':Skip_con}) #这里很重要
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值