利用 Flask 搭建 PyTorch 推理服务

PyTorch 不像 TensorFlow 那样可以借助现成的服务(TensorFlow Serving)直接部署模型,需要自己搭建服务才行。

这份代码是在我之前手动部署 TensorFlow 模型的代码基础上修改的,所以里面还留有一些 TensorFlow 的痕迹。

服务端:

from flask import Flask, render_template, request, redirect, url_for, flash, jsonify
#from tensorflow.keras.models import load_model
import cv2
import numpy as np
#import tensorflow as tf
import uuid
import numpy as np
# Load the legacy Keras model (left over from the TensorFlow version of this server).
def get_model():
    """Load 'model.h5' into the module-level ``graph`` and ``model`` globals.

    Only used by the old TensorFlow '/predict' handler; the PyTorch server
    does not call it.
    """
    # Imported here rather than at module level: the top-level TensorFlow
    # imports are commented out, so without these lines calling this function
    # raised NameError. Importing lazily also keeps TensorFlow an optional
    # dependency for the PyTorch server.
    import tensorflow as tf
    from tensorflow.keras.models import load_model

    global graph, model
    graph = tf.compat.v1.get_default_graph()
    with graph.as_default():
        model = load_model('model.h5')

# Flask application object; the '/predict' route and the error handler
# defined later in this file are registered on it.
app = Flask(__name__)
# Legacy TensorFlow setup: get_model() would populate the `graph`/`model`
# globals. It stays disabled because this server now serves a PyTorch model.
#get_model()
# NOTE: the triple-quoted string below is never executed. It preserves the
# original TensorFlow/Keras '/predict' handler for reference only; it depends
# on the `graph` and `model` globals that get_model() would create. The PyTorch
# '/predict' handler is registered separately later in this file.
"""
@app.route('/predict', methods=['POST'])
def upload_image():
    with graph.as_default():
        if request.method == 'POST':
            #print(request.get_json())
            response = {'success': True}
            print(response)
            # Prepare image for model
            get_image=request.get_json()['input_image']
            #print(get_image)
            #get_image_tensor=tf.convert_to_tensor(get_image)
            #resized = tf.image.resize(get_image_tensor, [256, 256])
            #print(get_image_tensor)
            #normalized = (get_image_tensor / 127.5) - 1
            #print(normalized)
            #batched = tf.expand_dims(get_image_tensor, 0)
            #print(batched)
            # Apply model!
            model.summary()
            predicted = model.predict(get_image,steps=1)
            print(predicted)
            predicted = predicted[0,:,:,:]
            print(predicted)
            unnormalized = (predicted + 1) / 2

            # Save result on disk to eventually show to user
            result_name = 'static/tmp/{}.jpg'.format(uuid.uuid4())
            cv2.imwrite(result_name, unnormalized * 255)

            response['processed_image'] = result_name
            return jsonify(response)
"""

import torch

# Load the trained generator once at import time so every request reuses it.
# map_location='cpu': the request handler builds CPU tensors, so the model has
# to live on the CPU as well; this also lets a checkpoint that was saved on a
# GPU machine load on a CPU-only host.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
# NOTE(review): this assumes the .pth file stores a whole nn.Module. If it holds
# a state_dict instead, torch.load returns a dict and .eval() fails; build the
# generator architecture and call load_state_dict() in that case. On
# PyTorch >= 2.6, loading a whole pickled module also needs weights_only=False.
new_model = torch.load('./checkpoints/coconew/latest_net_G.pth', map_location='cpu')
new_model.eval()  # inference mode (affects dropout / batch-norm layers)
print(new_model)
import numpy
def _to_model_input(data):
    """Convert the JSON ``input_image`` payload into a float32 NCHW tensor.

    The bundled client sends TensorFlow-style channel-last data (N, H, W, C),
    while PyTorch convolutions expect channel-first (N, C, H, W). A single
    image without a batch axis (3-D) is accepted as well.

    Raises:
        ValueError: if the payload is not a rectangular numeric 3-D/4-D array
            (routed to the ValueError error handler).
    """
    # float32 matches PyTorch's default parameter dtype (assumes the model was
    # not converted to double); numpy would otherwise produce float64 and the
    # forward pass would fail with a dtype mismatch.
    array = np.asarray(data, dtype=np.float32)
    if array.ndim == 3:
        array = array[np.newaxis]
    if array.ndim != 4:
        raise ValueError(
            'input_image must be a 3-D or 4-D array, got shape {}'.format(array.shape))
    # Heuristic: a trailing axis of size 1 or 3 that is not mirrored at axis 1
    # means the data is channel-last and must be moved to channel-first.
    if array.shape[-1] in (1, 3) and array.shape[1] not in (1, 3):
        array = array.transpose(0, 3, 1, 2)
    return torch.from_numpy(np.ascontiguousarray(array))


def _to_hwc(image_chw):
    """Return a (C, H, W) output tensor as a C-contiguous (H, W, C) numpy array."""
    return np.ascontiguousarray(image_chw.detach().cpu().numpy().transpose(1, 2, 0))


@app.route('/predict', methods=['POST'])
def upload_image():
    """Run the generator on a posted image and save the result as a JPEG.

    Expects a JSON body ``{"input_image": <nested list>}``; the bundled client
    sends a channel-last batch normalized to [-1, 1] (x / 127.5 - 1).

    Returns:
        JSON with ``success``, ``processed_image`` (path of the saved JPEG)
        and ``predictions`` (raw model output for the first image, as a
        batch-first, channel-last nested list). Responds 400 when the body or
        its ``input_image`` field is missing and 500 when the JPEG cannot be
        written.
    """
    import os

    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'input_image' not in payload:
        return jsonify({'success': False,
                        'message': "JSON body with an 'input_image' field is required"}), 400

    input_tensor = _to_model_input(payload['input_image'])
    # Inference only: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        predicted = new_model(input_tensor)
    # Assumes the model returns (N, C, H, W); keep only the first image.
    output_hwc = _to_hwc(predicted[0])

    # Map the output from [-1, 1] (assumed tanh-style range) to 8-bit pixels;
    # JPEG needs uint8 data.
    pixels = np.clip(np.rint((output_hwc + 1) / 2 * 255), 0, 255).astype(np.uint8)
    if pixels.shape[-1] == 3:
        # The client decodes JPEGs as RGB, but cv2.imwrite expects BGR.
        pixels = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)

    # Save result on disk to eventually show to user.
    result_name = 'static/tmp/{}.jpg'.format(uuid.uuid4())
    # cv2.imwrite silently returns False when the directory does not exist.
    os.makedirs(os.path.dirname(result_name), exist_ok=True)
    if not cv2.imwrite(result_name, pixels):
        return jsonify({'success': False,
                        'message': 'failed to write ' + result_name}), 500

    return jsonify({
        'success': True,
        'processed_image': result_name,
        # Batch-first, channel-last like the request, so the client can read
        # response.json()['predictions'].
        'predictions': output_hwc[np.newaxis].tolist(),
    })



# Flask routes any ValueError (or subclass) raised while handling a request to
# this handler instead of rendering a generic 500 page.
@app.errorhandler(ValueError)
def handle_error(error):
    """Report a ValueError raised during a request as a JSON 400 response.

    Returns ``{"success": False, "message": str(error)}`` with HTTP 400 so
    clients that check the status code (e.g. ``raise_for_status()``) see the
    failure instead of treating the reply as a successful prediction.
    """
    print("完蛋")  # failure marker in the server log ("完蛋" ~ "it's broken")
    return jsonify({'success': False, 'message': str(error)}), 400

if __name__ == "__main__":
    import os

    # Flask's debug mode turns on the interactive Werkzeug debugger, which can
    # execute arbitrary code; make it opt-in (FLASK_DEBUG=1) rather than
    # hard-coding it for a service.
    app.run(debug=os.environ.get("FLASK_DEBUG") == "1")

客户端:

from __future__ import print_function
#import tensorflow as tf
import base64
import requests
import cv2
# model with the name "pix2pix" and using the predict interface.
SERVER_URL = 'http://localhost:5000/predict'
import numpy as np

#from __future__ import absolute_import, division, print_function, unicode_literals
import json
import tensorflow as tf
import  numpy as np
import os
import time

from matplotlib import pyplot as plt
from IPython import display
#import tensorflow as tf
#import tensorflow.contrib.eager as tfe
#tf.enable_eager_execution()

#import tensorflow.contrib.eager as tfe
#tfe.enable_eager_execution()

#import tensorflow as tf
#tf.enable_eager_execution()

# Shuffle-buffer size and number of images per batch for the tf.data pipelines.
BUFFER_SIZE, BATCH_SIZE = 400, 1
# Spatial size (pixels) that images are resized / cropped to.
IMG_WIDTH, IMG_HEIGHT = 256, 256

## Load one side-by-side image pair
def load(image_file):
    """Read a side-by-side JPEG and split it into its two halves.

    The left half is returned as the target (``real_image``) and the right
    half as the model input (``input_image``). Both are cast to float32 and
    keep their original 0-255 pixel values.

    Args:
        image_file: path (string or string tensor) of the JPEG to read.

    Returns:
        ``(input_image, real_image)`` float32 tensors with 3 channels each.
    """
    image = tf.io.read_file(image_file)
    # channels=3 forces RGB output; without it a grayscale JPEG decodes to a
    # single channel and the fixed 3-channel crop in random_crop() fails.
    image = tf.image.decode_jpeg(image, channels=3)
    half_width = tf.shape(image)[1] // 2
    real_image = image[:, :half_width, :]
    input_image = image[:, half_width:, :]
    return tf.cast(input_image, tf.float32), tf.cast(real_image, tf.float32)
## Resize both images of a pair
def resize(input_image, real_image, height, width):
    """Resize both images to ``height`` x ``width`` with nearest-neighbour sampling."""
    target_size = [height, width]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
    resized = [tf.image.resize(img, target_size, method=nearest)
               for img in (input_image, real_image)]
    return resized[0], resized[1]
## Random crop applied identically to both images
def random_crop(input_image, real_image):
    """Crop the same random IMG_HEIGHT x IMG_WIDTH x 3 window from both images.

    Stacking the pair first guarantees both images share one crop offset.
    """
    pair = tf.stack([input_image, real_image], axis=0)
    crop_size = [2, IMG_HEIGHT, IMG_WIDTH, 3]
    cropped = tf.image.random_crop(pair, size=crop_size)
    return cropped[0], cropped[1]

# Map pixel values from [0, 255] to [-1, 1].
def normalize(input_image, real_image):
    """Scale both images from the [0, 255] pixel range to [-1, 1]."""
    def to_unit_range(pixels):
        return pixels / 127.5 - 1

    return to_unit_range(input_image), to_unit_range(real_image)


@tf.function()
def random_jitter(input_image, real_image):
    """Augment a training pair: upscale to 286x286, random-crop back, maybe mirror.

    The same crop window and the same flip decision apply to both images.
    """
    # Upscale first so the crop has room to shift, then crop to IMG_HEIGHT x IMG_WIDTH.
    stretched = resize(input_image, real_image, 286, 286)
    source, target = random_crop(*stretched)

    # Mirror horizontally with probability 0.5.
    if tf.random.uniform(()) > 0.5:
        source, target = (tf.image.flip_left_right(source),
                          tf.image.flip_left_right(target))

    return source, target

## Training pipeline: load, augment, normalize
def load_image_train(image_file):
    """Load one training pair, apply random jitter, and scale it to [-1, 1]."""
    pair = load(image_file)
    pair = random_jitter(*pair)
    return normalize(*pair)
## Test pipeline: load, resize to the model size, normalize
def load_image_test(image_file):
    """Load one test pair, resize it to IMG_HEIGHT x IMG_WIDTH, and scale it to [-1, 1]."""
    source, target = load(image_file)
    source, target = resize(source, target, IMG_HEIGHT, IMG_WIDTH)
    return normalize(source, target)




# Root of the facades dataset, containing train/ and test/ folders of JPEGs.
# (The previous os.path.join(os.path.dirname(p), p) always evaluated to p,
# because joining with an absolute path discards the earlier component.)
PATH = '/opt/AI/facades/'

# NOTE(review): the client only sends test images; building train_dataset
# still requires PATH/train/*.jpg to exist (list_files fails on no match).
train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.jpg')
train_dataset = train_dataset.map(load_image_train,
                                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)

test_dataset = tf.data.Dataset.list_files(PATH + 'test/*.jpg')
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(BATCH_SIZE)
import numpy
# Send `num` test batches to the server and save each returned image as ./md<i>.jpg.
num = 1
for i, (inp, _) in enumerate(test_dataset.take(num)):
    # The server reads the batch from the "input_image" key: channel-last
    # images already normalized to [-1, 1] by load_image_test().
    predict_request = {"input_image": inp.numpy().tolist()}
    # Without a timeout, requests waits forever on an unresponsive server.
    response = requests.post(SERVER_URL, json=predict_request, timeout=60)
    response.raise_for_status()
    print(response)

    body = response.json()
    if 'predictions' not in body:
        # A server that only saves the image on its side returns just the path.
        print('server saved result to', body.get('processed_image'))
        continue

    prediction = np.asarray(body['predictions'], dtype=np.float32)
    # Map the first image of the batch from [-1, 1] back to 8-bit pixels.
    image = np.clip(np.rint(127.5 * (prediction[0] + 1.0)), 0, 255).astype(np.uint8)
    if image.ndim == 3 and image.shape[-1] == 3:
        # tf.image.decode_jpeg produced RGB, but cv2.imwrite expects BGR.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    cv2.imwrite('./md' + str(i) + '.jpg', image)

客户端可以按需自己编写;这里直接沿用了 facades 数据集的加载代码,所以没有再改动。

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值