tensorflow: 模型恢复及使用模型进行测试

tensorflow模型选择采用ckpt文件格式进行保存的时候,会有四个文件:

checkpoint
3D_unet18001.ckpt.meta
3D_unet18001.ckpt.index
3D_unet18001.ckpt.data-00000-of-00001

把这四个看做一个整体即可,不用单独处理。

aim_shape = (128,128,96)

def prediction(model_path, img_path, save_path):

    x = tf.placeholder(tf.float32, shape=[1, aim_shape[0], aim_shape[1], aim_shape[2], 1], name='x')
    net = a_net_all(x, classes=2) # 网络模型

    init = tf.global_variables_initializer()
    sess = tf.InteractiveSession()
    sess.run(init)
    img_names = sorted(os.listdir(img_path))

    saver = tf.train.Saver()
    saver.restore(sess, model_path)
    for img_name in img_names:
        print(img_name)
        input_x = nib.load(os.path.join(img_path, img_name)).get_data() #读取nii文件
        x_batch = input_x[np.newaxis, :, :, :, np.newaxis]
        print(x_batch.shape)
        y = sess.run(net, feed_dict={x: x_batch})
        
        # 以下均是一些后处理
        result = y.astype(np.uint8)
        result = np.squeeze(result)
        
        result = result.transpose(2, 1, 0)
        img = itk.GetImageFromArray((result))
        itk.WriteImage(img, save_path + str(img_name))
        print(save_path+img_name)

在调用prediction函数的时候

model_path = ...model/3D_unet18001.ckpt
# 选择你想要用来预测的模型即可

需要注意的是3D_unet18001.ckpt这个文件并不存在,但是不影响调用,这是tensorflow版本更行之后出现的问题。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值