多模型相互嵌套调用,图冲突BUG(keras/tf)

前言:在普通的任务中,我们经常使用训练好的模型进行预测推理,单模型的推理任务处理版本的问题应该不会存在其他的BUG。但是在实际的业务处理场景中,往往会使用多个模型共同来处理,一个模型的输出可能就是另一个模型的输入之一或者是另一个模型的数据处理方式。

场景:模型B训练地过程中需要A模型不断地参与

使用错误error: tensorflow.python.framework.errors_impl.InvalidArgument:xxxxxx

1、Younger的处理方式:(适用于模型比较简单,命名较为规范)

直接load两个模型model_A、model_B,比如model_B中的数据预处理需要A模型的参与(或者在B训练的过程中需要A不断地参与),那么在load A之前需要对A进行下注册(load完进行一次推理),相当于使用下A,不然会报错,具体的error可以自己试验下:
eg:

# 伪代码
 model_A  = build_model_A(img_shape, classes=2)
 weight_path_A = './model_A_v1.h5'
 model_A .load_weights(weight_path_A )

 model_A.predict(XXXX)  # 

 model_B  = build_model_B(img_shape, classes=2)
 weight_path_B = './model_B_v1.h5'
 model_B  .load_weights(weight_path_B)
 
 def data_genetor():
 	img = cv2.imread(xxx)
 	img = model_A.predict(img)
 	return img,label
 	
 model_B.fit(data_genetor())
2、Older 处理方式

但是上述的方式不是一个长久之计,还是要规范化处理流程,上述的问题归根到底就是tensorflow-graoh的问题,所以解决问题的原理就是建立不同的session,建立不同的graph,分别在不同的graph下执行各自的推理或者训练,大家互不干扰,通过使用with语句进行图的使用和退出。
下面是伪代码
eg:

g1=tf.Graph()      # get_default_graph()
sess1 = tf.Session(graph=g1)
with sess1.as_default():
    with g1.as_default():
        model_A  = build_model_A(img_shape, classes=2)
        weight_path_2 = './model_A_v1.h5'
        model_A.load_weights(weight_path_2)
        
g2=tf.Graph()        
sess2 = tf.Session(graph=g2)
with sess2.as_default():
    with g2.as_default():
        model_B  = build_model_B(img_shape, classes=2)
 		model_B.fit(data_genetor())
 		
def data_genetor():
 	img = cv2.imread(xxx)
 	with sess1.as_default():
    	with g1.as_default():
 			img = model_A.predict(img)
 	return img,label
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值