使用tf2的saved_model进行推理

import tensorflow as tf
import cv2
from PIL import Image
import numpy as np
import colorsys
import os
import matplotlib.pyplot as plt


def resize_image(image, size):
    """ 等比例resize """
    iw, ih  = image.size
    w, h    = size

    scale   = min(w/iw, h/ih)
    nw      = int(iw*scale)
    nh      = int(ih*scale)

    image   = image.resize((nw,nh), Image.BICUBIC)
    new_image = Image.new('RGB', size, (128,128,128))
    new_image.paste(image, ((w-nw)//2, (h-nh)//2))

    return new_image, nw, nh


def preprocess_input(image):
    image = image / 127.5 - 1
    return image


input_shape = (512,512) # 与训练的时候一致
num_classes = 2 # 类别+1

def preProcessing(filepath):
    inputs = cv2.imread(filepath)
    old_img = Image.open(filepath)
    h,w = inputs.shape[0],inputs.shape[1]
    # print(f'初始图像size: {h},{w}')

    """ 数据预处理 """
    image_data, nw, nh  = resize_image(old_img, (input_shape[1], input_shape[0]))
    image_data  = np.expand_dims(preprocess_input(np.array(image_data, np.float32)), 0)

    return old_img,(h,w),(nw,nh),image_data


def postProcessing():
    """ 对预测结果进行后处理 """
    # resize回图像原始的大小
    pr = cv2.resize(pr_arrays, (w, h), interpolation = cv2.INTER_LINEAR)
    pr = pr.argmax(axis=-1) # 取出每一个像素点的种类
    seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))

    if num_classes <= 21:
        colors = [ (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128), 
                        (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128), 
                        (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128), 
                        (128, 64, 12)]
    else:
        hsv_tuples = [(x / num_classes, 1., 1.) for x in range(num_classes)]
        colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors))


    for c in range(num_classes):
        seg_img[:,:,0] += ((pr[:,: ] == c )*(colors[c][0] )).astype('uint8')
        seg_img[:,:,1] += ((pr[:,: ] == c )*(colors[c][1] )).astype('uint8')
        seg_img[:,:,2] += ((pr[:,: ] == c )*(colors[c][2] )).astype('uint8')

    resultImage = Image.fromarray(np.uint8(seg_img))
    image = Image.blend(old_img,resultImage,0.5)

    return image

def saveAndShow(image):
    savename = os.path.basename(filepath)[:-4]+"httpResult.jpg"
    savePath = 'servingOut/'
    if not os.path.exists(savePath):
        os.mkdir(savePath)

    image.save(savePath+savename)

    plt.title(os.path.basename(filepath))
    plt.imshow(image)
    plt.show()

if __name__ == '__main__':
    mymodel = tf.saved_model.load("test/1")
    while True:
        try:
            filepath = input('请输入待预测图像路径(输入c退出): ')
            if filepath == 'c':
                break        
            old_img,(h,w),(nw,nh),image_data = preProcessing(filepath=filepath)            
            pr = mymodel(image_data)[0]
            pr_arrays = pr.numpy()
            image = postProcessing()
            saveAndShow(image)
        except Exception as e:
            print(e)
            continue
    

是在httpClient.py(参考文章)的基础上改的,主要是导入模型和输入data进行推理:

mymodel = tf.saved_model.load("test/1")
pr = mymodel(image_data)[0]

这个pr目前是tensor类型,需要转成numpy,然后才可以进行后处理

pr_arrays = pr.numpy()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值