[keras] 多GPU运行设置/固定权重

####1.所需要的库
from keras.utils import multi_gpu_model
参考keras官方文档multi-gpu
keras.utils.multi_gpu_model(model, gpus=None, cpu_merge=True, cpu_relocation=False)

2.模型包装
 model = Model(inputs=input_rgb,outputs=softmax) #定义好的模型
 parallel_model=multi_gpu_model(model,gpus=4)  #用multigpu函数包装一下
3.自定义模型 继承callback类

由于模型在fit的时候,用的是多gpu模型,直接调用自带的callback checkpoint函数来存储模型参数的时候,会报错。
我们在定义模型函数的时候,可以返回两个模型,一个是原始模型,另一个是多GPU模型。
多gpu模型用来训练,调用fit/fit_generator函数,原始模型送入自定义callback保存.h5文件。

#模型定义函数
return parallel_model,model
#调用模型定义函数
self.model,self.original_model = self.atten_lstm()
#将原始模型传入自定义callback进行存储
save_model=MyCbk(rm.original_model)

class MyCbk(Callback):
    def __init__(self, model):
        self.model_to_save = model

    def on_epoch_end(self, epoch, logs=None):
        self.model_to_save.save('epoch_%d.h5' % epoch)
4.继承Modelcheckpoint类
class ParallelModelCheckpoint(ModelCheckpoint):

    def __init__(self,model,filepath, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
        self.single_model = model
        super(ParallelModelCheckpoint,self).__init__(filepath, monitor, verbose,save_best_only, save_weights_only,mode, period)

    def set_model(self, model):
        super(ParallelModelCheckpoint,self).set_model(self.single_model)

check_point = ParallelModelCheckpoint(single_model ,'best.hd5')

除了 ModelCheckpoint,其他的callback函数是可以正常调用的

    tb = TensorBoard(log_dir=os.path.join('..','..','data', 'logs',dataset,model))
    early_stopper = EarlyStopping(patience=100)
    checkpointer = ModelCheckpoint(filepath=os.path.join('..','..','data', 'checkpoints', dataset,model,data_type +'.{epoch:03d}-{val_loss:.3f}-{val_acc:.3f}.hdf5'),
        verbose=1,
        save_best_only=True)

有个问题,自定义的callback函数,怎么传入val_loss val_acc?

5.继承keras.callbacks.Callback
class CustomModelCheckpoint(Callback):
    def __init__(self, model, path):
        self.model = model
        self.path = path
        self.best_acc = 0
        print "init checkpoint saver"

    def on_epoch_end(self, epoch, logs):
        val_loss = logs['val_loss']
        val_acc =logs['val_acc']
        if val_acc > self.best_acc:
            print("\nValidation acc increased from {} to {}, saving model".format(self.best_acc, val_acc))
            self.model.save_weights(self.path+'/'+'acc_%.3f_loss_%.3f_epoch%d_.hdf5'%(val_acc,val_loss,epoch), overwrite=True)
            self.best_acc = val_acc
            print self.path+'/'+'acc_%.3f_loss_%.3f_epoch%d_.hdf5'%(val_acc,val_loss,epoch)


model.fit(X_train, y_train,
              batch_size=batch_size*G, epochs=nb_epoch, verbose=0, shuffle=True,
              validation_data=(X_valid, y_valid),
              callbacks=[CustomModelCheckpoint(model, '/path/to/save/model.h5')])

用这种方法的好处是可以知道val_loss

固定某些层的权重

for layer in model.layers[:-6]: #除了后6层,其他层的权重固定
            layer.trainable = False
for layer in base_model.layers:  
    layer.trainable = False
   
 ## 可以直接传参数进去
 frozen_layer = Dense(32, trainable=False)
 ## 也可以初始化后再设置trainable =False,但是必须重新compile
x = Input(shape=(32,))
layer = Dense(32)
layer.trainable = False
y = layer(x)

frozen_model = Model(x, y)
# in the model below, the weights of `layer` will not be updated during training
frozen_model.compile(optimizer='rmsprop', loss='mse')

参考
https://blog.csdn.net/u012862372/article/details/80367607
https://blog.csdn.net/z5217632/article/details/80952372

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值