Keras中Lambda自定义层保存和多进程训练


前言

记录:如何在Keras同时使用Lambda自定义层以及多进程训练


一、Lambda层如何保存

1.Lambda层定义

Lambda层可以自定义函数(如Truncation)以及函数传参(如clip_value_min, clip_value_max, zero_thld)来构建网络层,代码示例:

L1_1 = Lambda(lambda x: Truncation(x,clip_value_min,clip_value_max,zero_thld),name='max_min_truncat_{:d}'.format(truncat_cnt))(L1_1);
truncat_cnt+=1

1.Lambda层保存

Lambda一般使用keras.Model.save()函数来存储为’.h5’文件,但要注意该函数只会保存全局import的函数或者类,如果保存import在函数体内部的函数或者类,load_model时会报错:‘str’ object is not callable,代码如下(示例):

#-----------User Files-------------#
from get_dataset_utils import get_dataset_pp
from get_info_utils import Get_An_Na_Acc
from user_def_utils import Truncation,DDQC,MDQC,DC_MDQC,OUT_MDQC,Activate_Quantization,Get_Quant_Weights,Get_Quant_Model_Weights

def Tfun_Train_LY(start_thld,end_thld):
	# from user_def_utils import Truncation (will cause error)
	...
	model.save(model_dir+'/save_model/ecg-layer2-echo{:0>3d}.h5'.format((i+1)*eval_echo_after_train))
	print('save model successfully!')
	...

2.多进程训练

注意对于keras这样的训练库一定要定义在目标进程函数体内,否则会因为多进程访问同样keras对象而报错,代码如下(示例):

#-----------User Files-------------#
from get_dataset_utils import get_dataset_pp
from get_info_utils import Get_An_Na_Acc
from user_def_utils import Truncation,DDQC,MDQC,DC_MDQC,OUT_MDQC,Activate_Quantization,Get_Quant_Weights,Get_Quant_Model_Weights

def Tfun_Train_LY(start_thld,end_thld):
	# from user_def_utils import Truncation (will cause error)
	...
	model.save(model_dir+'/save_model/ecg-layer2-echo{:0>3d}.h5'.format((i+1)*eval_echo_after_train))
	print('save model successfully!')
	...

import multiprocessing
import os

if __name__ == '__main__':
    min_tar_num = 0
    max_tar_num = 29
    process_total = 15 # 设置起飞并行度参数
    per_process_num = round((max_tar_num-min_tar_num + 1) / process_total)
    process_remain_tar_num = (max_tar_num-min_tar_num + 1) - per_process_num * (process_total - 1)

    #Tfun_Threshold_Acc(min_tar_num, max_tar_num)
    
    process_list = []
    # 创建进程
    try:
        for i in range(process_total-1):
            process_list.append(multiprocessing.Process(target = Tfun_Train_LY, args = (i*per_process_num, (i+1)*per_process_num)))
            process_list[i].start()
        if process_total == 1:
            process_list.append(multiprocessing.Process(target = Tfun_Train_LY, args = (min_tar_num, max_tar_num+1)))
            process_list[0].start()
        else:
            process_list.append(multiprocessing.Process(target = Tfun_Train_LY, args = ((i+1)*per_process_num, (i+1)*per_process_num+process_remain_tar_num)))
            process_list[i+1].start()

        # 等待所有进程结束
        for prs in process_list:
            prs.join()
    except:
        print ("Error: 无法启动线程")

总结

以上就是今天要讲的内容,本文仅仅简单介绍了Keras中Lambda自定义层保存和多进程训练,仅供参考。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值