前言:在普通的任务中,我们经常使用训练好的模型进行预测推理,单模型的推理任务处理版本的问题应该不会存在其他的BUG。但是在实际的业务处理场景中,往往会使用多个模型共同来处理,一个模型的输出可能就是另一个模型的输入之一或者是另一个模型的数据处理方式。
场景:模型B训练地过程中需要A模型不断地参与
使用错误error: tensorflow.python.framework.errors_impl.InvalidArgument:xxxxxx
1、Younger的处理方式:(适用于模型比较简单,命名较为规范)
直接load两个模型model_A、model_B,比如model_B中的数据预处理需要A模型的参与(或者在B训练的过程中需要A不断地参与),那么在load A之前需要对A进行下注册(load完进行一次推理),相当于使用下A,不然会报错,具体的error可以自己试验下:
eg:
# 伪代码
model_A = build_model_A(img_shape, classes=2)
weight_path_A = './model_A_v1.h5'
model_A .load_weights(weight_path_A )
model_A.predict(XXXX) #
model_B = build_model_B(img_shape, classes=2)
weight_path_B = './model_B_v1.h5'
model_B .load_weights(weight_path_B)
def data_genetor():
img = cv2.imread(xxx)
img = model_A.predict(img)
return img,label
model_B.fit(data_genetor())
2、Older 处理方式
但是上述的方式不是一个长久之计,还是要规范化处理流程,上述的问题归根到底就是tensorflow-graoh的问题,所以解决问题的原理就是建立不同的session,建立不同的graph,分别在不同的graph下执行各自的推理或者训练,大家互不干扰,通过使用with语句进行图的使用和退出。
下面是伪代码
eg:
g1=tf.Graph() # get_default_graph()
sess1 = tf.Session(graph=g1)
with sess1.as_default():
with g1.as_default():
model_A = build_model_A(img_shape, classes=2)
weight_path_2 = './model_A_v1.h5'
model_A.load_weights(weight_path_2)
g2=tf.Graph()
sess2 = tf.Session(graph=g2)
with sess2.as_default():
with g2.as_default():
model_B = build_model_B(img_shape, classes=2)
model_B.fit(data_genetor())
def data_genetor():
img = cv2.imread(xxx)
with sess1.as_default():
with g1.as_default():
img = model_A.predict(img)
return img,label