我的模型有两组数据输入,所以没法直接使用组装层来实现,使用keras.Model 由于实在是不熟悉tf 与keras,遇到了种种困难包括但不限于网络中的训练参数怎么添加、多输入的处理、Model.build的定义失败、add_weight不会用、网络summary没有参数,最后发现了一个简单的方法,真是结了燃眉之急
class MYM(keras.Model):
def __init__(self):
ipt1 = keras.Input(shape=(13,13,7),name="view")
ipt2 = keras.Input(shape=(34),name="feature")
x = layers.Conv2D(7,kernel_size=3)(ipt1)
x = layers.Conv2D(1,kernel_size=3)(x)
x = layers.Flatten()(x)
#print(x,ipt2)
x =tf.concat([x,ipt2],axis=-1)
x = layers.Dense(128)(x)
#print(x)
x = layers.Dense(21)(x)
out = layers.Softmax(axis=-1)(x)
super(MYM,self).__init__(inputs=[ipt1,ipt2],outputs=out)
summary
net = MYM()
net.summary()