针对tensorflow2.0 在训练模型中添加了自定义类后不能保存为.h5的例子, 修改方法主要采用引入一个def get_config(self)的子类,
先继承sequential的config,然后采用字典进行保存对应的self的输入变量即可:
本帖子很明晰那输入时in_fan和out_fan,在def get_config(self)的子类中,需要进行updata这两个数据,然后保存为字典格式
#在训练的模型中必须定义config
class Skip_con(tf.keras.layers.Layer):
def __init__(self, in_fan, out_fan,**kwargs):
super(Skip_con,self).__init__()
self.in_fan=in_fan
self.out_fan=out_fan
self.liner=tf.keras.layers.Dense(out_fan,use_bias=False,activation='relu')
self.transform=tf.keras.layers.Dense(in_fan,use_bias=False)
self.bn1=tf.keras.layers.BatchNormalization()
self.bn2=tf.keras.layers.BatchNormalization()
def call(self,x,**kwargs):
_x=self.liner(x)
if self.in_fan==self.out_fan:
return self.bn1(x+_x)
elif self.in_fan!=self.out_fan:
x_=self.transform(_x)
return self.bn2(x+x_)
# def get_config(self):
# config = {'in_fan': self.in_fan,
# 'out_fan': self.out_fan}
# base_config=super(Skip_con, self).get_config()
# return dict(list(base_config.items()) + list(config.items()))
def get_config(self):
config = super(Skip_con, self).get_config()
config.update({'in_fan':self.in_fan,
'out_fan':self.out_fan})
return config
#在调用的过程中,建立新的程序,必须先再次定义类,然后采用custom_objects声明调用的类
```python
class Skip_con(tf.keras.layers.Layer):
def __init__(self, in_fan, out_fan,**kwargs):
super(Skip_con,self).__init__()
self.in_fan=in_fan
self.out_fan=out_fan
self.liner=tf.keras.layers.Dense(out_fan,use_bias=False,activation='relu')
self.transform=tf.keras.layers.Dense(in_fan,use_bias=False)
self.bn1=tf.keras.layers.BatchNormalization()
self.bn2=tf.keras.layers.BatchNormalization()
def call(self,x,**kwargs):
_x=self.liner(x)
if self.in_fan==self.out_fan:
return self.bn1(x+_x)
elif self.in_fan!=self.out_fan:
x_=self.transform(_x)
return self.bn2(x+x_)
def get_config(self):
config = super(Skip_con, self).get_config()
config.update({'in_fan':self.in_fan,
'out_fan':self.out_fan})
return config
model=tf.keras.models.load_model('weinian.h5',custom_objects={'Skip_con':Skip_con}) #这里很重要