在使用Keras时 ,加载模型时候可能会存在这个问题 ,那是什么原因遭成的了,可能是 使用了
multi_gpu_model(),保存的是一个multi gpu 的模型 那我如何避免该错误了,
from keras.utils import multi_gpu_model
ori_model = build_model()
model= multi_gpu_model(ori_model, gpus=2)
model.load_weights('best_model.h5')
那问题有来了 如何 将multi_gpu_model 在单个gpu在运行?
解决方案是
ori_model = build_model()
model= multi_gpu_model(ori_model, gpus=2)
model.load_weights('best_model.h5')
model.compile(loss='sparse_categorical_crossentropy',optimizer=RAdam(lr=0.0001))
ori_model.save_weights("single_gpu.h5")
# 再去加载模型
ori_model.load_weights("single_gpu.h5")