tensorflow模型选择采用ckpt文件格式进行保存的时候,会有四个文件:
checkpoint
3D_unet18001.ckpt.meta
3D_unet18001.ckpt.index
3D_unet18001.ckpt.data-00000-of-00001
把这四个看做一个整体即可,不用单独处理。
aim_shape = (128,128,96)
def prediction(model_path, img_path, save_path):
x = tf.placeholder(tf.float32, shape=[1, aim_shape[0], aim_shape[1], aim_shape[2], 1], name='x')
net = a_net_all(x, classes=2) # 网络模型
init = tf.global_variables_initializer()
sess = tf.InteractiveSession()
sess.run(init)
img_names = sorted(os.listdir(img_path))
saver = tf.train.Saver()
saver.restore(sess, model_path)
for img_name in img_names:
print(img_name)
input_x = nib.load(os.path.join(img_path, img_name)).get_data() #读取nii文件
x_batch = input_x[np.newaxis, :, :, :, np.newaxis]
print(x_batch.shape)
y = sess.run(net, feed_dict={x: x_batch})
# 以下均是一些后处理
result = y.astype(np.uint8)
result = np.squeeze(result)
result = result.transpose(2, 1, 0)
img = itk.GetImageFromArray((result))
itk.WriteImage(img, save_path + str(img_name))
print(save_path+img_name)
在调用prediction函数的时候
model_path = ...model/3D_unet18001.ckpt
# 选择你想要用来预测的模型即可
需要注意的是3D_unet18001.ckpt这个文件并不存在,但是不影响调用,这是tensorflow版本更行之后出现的问题。