踩坑
用不同的h5文件来保存模型时,不管是用save还是load,虽然写法比较简单,但是经常遇到各种各样的毛病。还有save_weights和load_weights,尽管能用,但是仍然存在问题。
最后发现了一种最好用的方法:训练时用ModelCheckpoint回调只保存权重(save_weights_only=True),使用时先重新构建并编译模型,再用load_weights加载权重。
保存
# Build and compile the network, train it while checkpointing the weights,
# then run prediction on the test set.
Model1 = Net1()
scce = tf.keras.losses.SparseCategoricalCrossentropy()
# NOTE: epsilon=10e-8 is numerically the same value as 1e-7.
adam = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.99, epsilon=10e-8, amsgrad=False, name="Aadm")
Model1.compile(optimizer=adam, loss=scce, metrics=['accuracy'])

# save_weights_only=True makes the callback write only the weights (not the
# whole model) to checkpoint_save_path.
checkpoint_save_path = './/model1//Model1.ckpt'
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True)

# `callbacks` is documented as a list of callback instances, so the single
# checkpoint callback is wrapped in a list rather than passed bare.
Model1.fit(k_x_train, k_y_train, batch_size=64, epochs=30, callbacks=[cp_callback])
predictions_nn1 = Model1.predict(X_test)
加载
# Rebuild and compile the network, restore the trained weights from the
# checkpoint, then run prediction on the test set.
# NOTE(review): net.Net1 is presumably the same architecture that was trained;
# the weights can only be restored into a matching model - confirm.
model1 = net.Net1()
scce = tf.keras.losses.SparseCategoricalCrossentropy()
# NOTE: epsilon=10e-8 is numerically the same value as 1e-7.
adam = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.99, epsilon=10e-8, amsgrad=False, name="Aadm")
model1.compile(optimizer=adam, loss=scce, metrics=['accuracy'])

# This must be the same path that was given to ModelCheckpoint(filepath=...)
# during training; adjust it if the checkpoint files were moved.
checkpoint_save_path = './/model1//Model1.ckpt'
model1.load_weights(checkpoint_save_path)
predictions_nn = model1.predict(X_test)