1. 网上常见的有bug的代码
from keras.utils import multi_gpu_model
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"
model = vgg()
parallel_model = multi_gpu_model(model, gpus=2)
parallel_model.compile(loss='categorical_crossentropy', optimizer='adam')
parallel_model.fit(x, y, epochs=20, batch_size=256)
使用这个代码可以正常训练和验证,但是在保存模型时会报错:
TypeError: can't pickle ... objects
这是因为keras此时默认保存parallel_model
,但是keras一保存parallel_model
就会报这个错误,此时只能保存model
。
当你只需要保存最终的模型时,可以使用以下方法:
model.save('xxx.h5')
但是,大多数人都不会只保存最终的模型吧,起码我是习惯使用checkpoint来保存模型,那就看接下来的小节吧!
2. 正确的姿势
from keras.utils import multi_gpu_model
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"
class ParallelModelCheckpoint(ModelCheckpoint):
def __init__(self,model,filepath, monitor='val_loss', verbose=0,
save_best_only=False, save_weights_only=False,
mode='auto', period=1):
self.single_model = model
super(ParallelModelCheckpoint,self).__init__(filepath, monitor, verbose,save_best_only, save_weights_only,mode, period)
def set_model(self, model):
super(ParallelModelCheckpoint,self).set_model(self.single_model)
paralle_model = multi_gpu_model(model, gpus=2)
paralle_model.compile(optimizer=Optimizer, loss=Loss, metrics=Metrics)
model_path = "unet-{epoch:02d}-{val_loss:.4f}-{val_acc:.4f}-{val_iou:.4f}.hdf5"
model_checkpoint = ParallelModelCheckpoint(model, model_path, monitor='val_loss', mode='min', verbose=1, save_best_only=False)
3. 更深层次的multi gpu parallel的bug
multi_gpu_model doesn't work with stateful models
-
详细来说,keras在内部实现网络层时,都不需要明确指定batch_size,涉及到reshape等操作时,都直接用-1代替,batch_size那个维度会显示None。因此,我们在定义模型时,也不需要明确指定batch_size。
-
如果自己需要自定义一个网络层,需要明确指定batch_size来进行一些操作时,我们可以通过如下方式来明确告诉模型我们的batch_size大小:
inputs = Input(batch_shape=(batch_size,) + (height, width, channel))
但是,一旦通过multi_gpu_model进行模型编译之后,就会报错,原因是,我们给model
明确了batch_size的大小,但是paralle_model
并不知道,因为内部并没有明确指定batch_size被平分给多个gpu,所以,当真正分配到每个gpu时,batch_size又会显示None,此时就会报错。
具体的解决办法我现在还没有尝试,下面先提供几个参考链接:
https://github.com/keras-team/keras/issues/8397
https://github.com/visionscaper/stateful_multi_gpu/blob/master/util.py
https://github.com/kuixu/keras_multi_gpu