记录一下使用 TensorFlow Serving 部署图像分割模型的过程
一、将 h5 权重文件转换为可供 TensorFlow Serving 部署的 SavedModel 格式模型
changeH5tosavedModel.py
import tensorflow as tf
from nets.unet import Unet as unet
if __name__ == '__main__':
    # Rebuild the network with the same constructor arguments used for training
    # (positional args presumably: input shape, class count, backbone name —
    # they agree with input_shape/num_classes used by the HTTP client), then
    # restore the trained weights from the .h5 checkpoint.
    net = unet((512, 512, 3), 2, 'vgg')
    net.load_weights('EP100-loss0.196-valoss0.284.h5')
    # Export as a SavedModel into the numeric version directory "1" under
    # "test"; the "test" folder is what gets mounted into the serving container.
    tf.saved_model.save(net, "test/1")
二、利用 Docker 启动 TensorFlow Serving 服务
CPU 版本:
docker run -p 8501:8501 --mount type=bind,source=E:\projectFiles\standard\unetV1/test,target=/models/unetV1 -e MODEL_NAME=unetV1 -t tensorflow/serving
GPU 版本(目前只在 Linux 下测试过,因为 Win10 上似乎无法安装 nvidia-docker):
首先安装 NVIDIA 驱动和 NVIDIA Container Toolkit(nvidia-docker),然后用下面的命令验证容器内能否访问 GPU(能正常输出 nvidia-smi 的显卡信息即说明环境可用):
docker run --rm --gpus all nvidia/cuda:11.0-base nvidia-smi
然后拉取tensorflow-serving gpu镜像:
docker pull tensorflow/serving:latest-gpu
最后启动模型服务:
docker run --gpus all -p 8501:8501 --mount type=bind,source=/home/hbli/pythonFiles/unetV1/test,target=/models/unetV1 -e MODEL_NAME=unetV1 -t tensorflow/serving:latest-gpu
MODEL_NAME 可以自己定,但 target 路径最后的 unetV1 必须与 MODEL_NAME 一致(容器默认从 /models/<MODEL_NAME> 加载模型)。source 是宿主机上被部署模型所在的文件夹,即第一步导出的 test 目录,其下需包含数字版本子目录(如 1)。其余参数与上面的 CPU 版本命令相同。
三、客户端进行访问
httpClient.py
""" 图像分割的serving """
import cv2
import numpy as np
import requests
import json
import time
from PIL import Image
import colorsys
import matplotlib.pyplot as plt
import os
def resize_image(image, size):
    """Letterbox ``image`` into ``size`` while keeping its aspect ratio.

    The image is scaled (bicubic) by the largest factor that still fits inside
    ``size``, then pasted centred onto a grey (128, 128, 128) RGB canvas that
    is exactly ``size``.

    Args:
        image: PIL image to resize.
        size: ``(width, height)`` of the output canvas.

    Returns:
        tuple: ``(canvas, nw, nh)`` — the padded RGB image plus the width and
        height of the scaled content inside it.
    """
    src_w, src_h = image.size
    dst_w, dst_h = size
    ratio = min(dst_w / src_w, dst_h / src_h)
    scaled_w, scaled_h = int(src_w * ratio), int(src_h * ratio)
    scaled = image.resize((scaled_w, scaled_h), Image.BICUBIC)

    canvas = Image.new('RGB', size, (128, 128, 128))
    # Centre the content; the leftover border stays grey.
    offset = ((dst_w - scaled_w) // 2, (dst_h - scaled_h) // 2)
    canvas.paste(scaled, offset)
    return canvas, scaled_w, scaled_h
def preprocess_input(image):
    """Map pixel values from [0, 255] to [-1, 1] via ``x / 127.5 - 1``.

    Works on scalars and on NumPy arrays (element-wise).
    """
    scaled = image / 127.5
    return scaled - 1
# Network input size as (height, width); must match the size used in training.
input_shape = (512, 512)
# Number of output classes: class count + 1 (presumably the background class).
num_classes = 2
def preProcessing(filepath):
    """Load an image from disk and build the network input batch.

    Args:
        filepath: path of the image to segment.

    Returns:
        tuple: ``(old_img, (h, w), (nw, nh), image_data)`` where ``old_img`` is
        the original image converted to RGB, ``(h, w)`` its height and width,
        ``(nw, nh)`` the size of the scaled content inside the letterboxed
        network input, and ``image_data`` a float32 batch of shape
        ``(1, input_shape[0], input_shape[1], 3)`` ready to send to the server.

    Raises:
        Whatever ``Image.open`` raises for a missing or unreadable file
        (e.g. FileNotFoundError).
    """
    # Read the image once, with PIL only. The former extra cv2.imread was used
    # just for the size; it returned None (-> opaque AttributeError) on missing
    # or non-ASCII paths, and it applies EXIF orientation while PIL does not,
    # so its (h, w) could disagree with old_img and break Image.blend later.
    # Converting to RGB makes the later blend work for RGBA/grayscale/palette files.
    with Image.open(filepath) as img:
        old_img = img.convert('RGB')
    w, h = old_img.size

    # Letterbox to the network size (resize_image takes (width, height)),
    # normalise to [-1, 1] and add a leading batch axis.
    image_data, nw, nh = resize_image(old_img, (input_shape[1], input_shape[0]))
    image_data = np.expand_dims(preprocess_input(np.array(image_data, np.float32)), 0)
    return old_img, (h, w), (nw, nh), image_data
def mainProcess():
    """Run inference on the module-level ``image_data`` via TF Serving's REST API.

    ``image_data`` is the batch produced by ``preProcessing`` and is set as a
    global by the ``__main__`` loop.

    Returns:
        numpy.ndarray: the first (only) element of the ``outputs`` batch.

    Raises:
        RuntimeError: if the server answers with an error status or with an
            ``error`` field instead of ``outputs``.
    """
    start = time.time()
    ####--------------------------core code----------------------------------------####
    # REST API endpoint
    url = 'http://localhost:8501/v1/models/unetV1:predict'
    # The request body must be JSON; the "inputs" key selects the columnar format.
    data = json.dumps({'inputs': image_data.tolist()})
    response = requests.post(url, data=data)
    # On failure the server replies with an error status (and a JSON
    # {"error": ...} body); report that instead of an opaque KeyError: 'outputs'
    # or a JSON decode error on a non-JSON body.
    if not response.ok:
        raise RuntimeError(
            f'TF Serving request failed (HTTP {response.status_code}): {response.text}'
        )
    result = json.loads(response.content)
    if 'error' in result:
        raise RuntimeError(f'TF Serving returned an error: {result["error"]}')
    outputs = result['outputs'][0]
    output_array = np.array(outputs)  # list -> numpy array
    ####--------------------------core code---------------------------------------####
    print(f'花费时间:{time.time()-start:.2f}s')
    return output_array
def postProcessing():
    """Turn the raw network output into a colour overlay on the original image.

    Reads module-level globals set by the ``__main__`` loop:
    ``output_array`` (per-pixel class scores from ``mainProcess``),
    ``old_img`` (the original PIL image) and ``nw``/``nh`` (size of the scaled
    content inside the letterboxed input, from ``resize_image``), plus the
    ``num_classes`` constant.

    Returns:
        PIL.Image.Image: the original image blended (alpha 0.7) with the
        class-colour mask.
    """
    # Strip the grey letterbox padding added by resize_image before scaling
    # back; otherwise the padding is stretched into the result and the mask no
    # longer lines up with non-square images. resize_image centres the content,
    # so the offsets mirror its paste() call.
    # NOTE(review): assumes the model output has the same spatial size as the
    # network input, as this UNet's does.
    out_h, out_w = output_array.shape[0], output_array.shape[1]
    top, left = (out_h - nh) // 2, (out_w - nw) // 2
    pr = output_array[top:top + nh, left:left + nw]

    # Resize back to the original image's own size so the mask always matches
    # the image it is blended with.
    orig_w, orig_h = old_img.size
    pr = cv2.resize(pr, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
    pr = pr.argmax(axis=-1)  # class index of every pixel

    if num_classes <= 21:
        colors = [(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128),
                  (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128),
                  (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
                  (128, 64, 128)]
    else:
        # Evenly spaced hues, one per class.
        hsv_tuples = [(x / num_classes, 1., 1.) for x in range(num_classes)]
        colors = [tuple(int(ch * 255) for ch in colorsys.hsv_to_rgb(*hsv)) for hsv in hsv_tuples]

    # Paint each class's pixels with its colour; pixels of any other index stay black.
    seg_img = np.zeros((pr.shape[0], pr.shape[1], 3), dtype=np.uint8)
    for c in range(num_classes):
        seg_img[pr == c] = colors[c]

    resultImage = Image.fromarray(seg_img)
    # Image.blend needs identical modes; the overlay is RGB, so make sure the
    # original is RGB too (it may be RGBA, grayscale or palette).
    image = Image.blend(old_img.convert('RGB'), resultImage, 0.7)
    return image
def saveAndShow(image):
    """Save ``image`` under ``servingOut/`` and display it with matplotlib.

    The output file name is the base name of the module-level ``filepath``
    (set by the ``__main__`` loop) without its extension, followed by
    ``httpResult.jpg``; the plot is titled with the input's base name.
    """
    # splitext handles any extension length; the former [:-4] slice assumed a
    # 3-character extension (e.g. "x.jpeg" became "x.httpResult.jpg").
    stem = os.path.splitext(os.path.basename(filepath))[0]
    savename = stem + "httpResult.jpg"
    savePath = 'servingOut/'
    os.makedirs(savePath, exist_ok=True)
    image.save(os.path.join(savePath, savename))

    plt.title(os.path.basename(filepath))
    plt.imshow(image)
    plt.show()
if __name__ == '__main__':
    # Interactive loop. The helper functions read each step's results from
    # module-level globals (filepath, image_data, output_array, old_img, nw,
    # nh, ...), so these assignments must stay at module scope.
    while True:
        try:
            filepath = input('请输入待预测图像路径(输入c退出): ')
        except EOFError:
            # stdin closed (Ctrl-D / exhausted pipe). Previously this was caught
            # by the generic handler below and the loop spun forever.
            break
        # Tolerate surrounding whitespace and quotes (e.g. Windows "Copy as path").
        filepath = filepath.strip().strip('"\'')
        if filepath == 'c':
            break
        try:
            old_img, (h, w), (nw, nh), image_data = preProcessing(filepath=filepath)
            output_array = mainProcess()
            image = postProcessing()
            saveAndShow(image)
        except Exception as e:
            # Keep the loop alive on a bad path or a server error, but show the
            # exception type as well; e.g. a KeyError's message alone is just a quoted key.
            print(f'{type(e).__name__}: {e}')