Keras中Lambda自定义层的保存
前言
记录:如何在Keras同时使用Lambda自定义层以及多进程训练
一、Lambda层如何保存
1.Lambda层定义
Lambda层可以自定义函数(如Truncation)以及函数传参(如clip_value_min, clip_value_max, zero_thld)来构建网络层,代码示例:
L1_1 = Lambda(lambda x: Truncation(x,clip_value_min,clip_value_max,zero_thld),name='max_min_truncat_{:d}'.format(truncat_cnt))(L1_1);
truncat_cnt+=1
1.Lambda层保存
Lambda一般使用keras.Model.save()函数来存储为’.h5’文件,但要注意该函数只会保存全局import的函数或者类,如果保存import在函数体内部的函数或者类,load_model时会报错:‘str’ object is not callable,代码如下(示例):
#-----------User Files-------------#
from get_dataset_utils import get_dataset_pp
from get_info_utils import Get_An_Na_Acc
from user_def_utils import Truncation,DDQC,MDQC,DC_MDQC,OUT_MDQC,Activate_Quantization,Get_Quant_Weights,Get_Quant_Model_Weights
def Tfun_Train_LY(start_thld,end_thld):
# from user_def_utils import Truncation (will cause error)
...
model.save(model_dir+'/save_model/ecg-layer2-echo{:0>3d}.h5'.format((i+1)*eval_echo_after_train))
print('save model successfully!')
...
2.多进程训练
注意对于keras这样的训练库一定要定义在目标进程函数体内,否则会因为多进程访问同样keras对象而报错,代码如下(示例):
#-----------User Files-------------#
from get_dataset_utils import get_dataset_pp
from get_info_utils import Get_An_Na_Acc
from user_def_utils import Truncation,DDQC,MDQC,DC_MDQC,OUT_MDQC,Activate_Quantization,Get_Quant_Weights,Get_Quant_Model_Weights
def Tfun_Train_LY(start_thld,end_thld):
# from user_def_utils import Truncation (will cause error)
...
model.save(model_dir+'/save_model/ecg-layer2-echo{:0>3d}.h5'.format((i+1)*eval_echo_after_train))
print('save model successfully!')
...
import multiprocessing
import os
if __name__ == '__main__':
min_tar_num = 0
max_tar_num = 29
process_total = 15 # 设置起飞并行度参数
per_process_num = round((max_tar_num-min_tar_num + 1) / process_total)
process_remain_tar_num = (max_tar_num-min_tar_num + 1) - per_process_num * (process_total - 1)
#Tfun_Threshold_Acc(min_tar_num, max_tar_num)
process_list = []
# 创建进程
try:
for i in range(process_total-1):
process_list.append(multiprocessing.Process(target = Tfun_Train_LY, args = (i*per_process_num, (i+1)*per_process_num)))
process_list[i].start()
if process_total == 1:
process_list.append(multiprocessing.Process(target = Tfun_Train_LY, args = (min_tar_num, max_tar_num+1)))
process_list[0].start()
else:
process_list.append(multiprocessing.Process(target = Tfun_Train_LY, args = ((i+1)*per_process_num, (i+1)*per_process_num+process_remain_tar_num)))
process_list[i+1].start()
# 等待所有进程结束
for prs in process_list:
prs.join()
except:
print ("Error: 无法启动线程")
总结
以上就是今天要讲的内容,本文仅仅简单介绍了Keras中Lambda自定义层保存和多进程训练,仅供参考。