此文仅作个人备份之用，不做其他解释。
服务器端：
from flask import Flask, request
import json
import numpy as np
import sys
import traceback
import cv2
import tensorflow as tf
import neuralgym as ng
from inpaint_model import InpaintCAModel
# Checkpoint directory the route reads model weights from via
# tf.train.load_variable (presumably the released Places2 256 model, judging
# by the name).
checkpoint_dir = 'model_logs/release_places2_256'

# Prefix for the saved uploads and the inpainted result. Paths are built by
# plain string concatenation, so keep the trailing '/'.
output_dir = 'static/'

app = Flask(import_name=__name__)
def _prepare_input(image, mask, grid=8):
    """Validate a decoded image/mask pair and pack it into one model input.

    Both arrays are cropped so their height and width are multiples of
    ``grid``. Each then gets a leading batch axis, and the two are
    concatenated along axis 2 of the batched array, i.e. side by side
    along the width.

    :param image: array returned by ``cv2.imread`` (None if decoding failed).
    :param mask: array returned by ``cv2.imread`` (None if decoding failed).
    :param grid: cropped height/width are rounded down to a multiple of this.
    :returns: array of shape ``(1, H, 2 * W, C)``.
    :raises ValueError: if either input is None, the shapes differ, or the
        image is smaller than ``grid`` in height or width.
    """
    # cv2.imread reports an unreadable/undecodable file by returning None,
    # not by raising.
    if image is None or mask is None:
        raise ValueError('image or mask could not be decoded')
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if image.shape != mask.shape:
        raise ValueError('image shape {} does not match mask shape {}'.format(
            image.shape, mask.shape))
    h, w, _ = image.shape
    # Anything smaller than one grid cell would be cropped to an empty array.
    if h < grid or w < grid:
        raise ValueError('image must be at least {0}x{0} pixels, got {1}x{2}'.format(
            grid, h, w))
    h, w = h // grid * grid, w // grid * grid
    image = image[:h, :w, :]
    mask = mask[:h, :w, :]
    print('Shape of image: {}'.format(image.shape))
    return np.concatenate([np.expand_dims(image, 0), np.expand_dims(mask, 0)], axis=2)


def _run_inpaint(flags, input_image):
    """Build the inpainting graph, load the checkpoint and evaluate it once.

    The network is built inside a fresh ``tf.Graph`` on every call, not in
    the process-wide default graph. Otherwise each request would add another
    copy of the network to the default graph. That graph would grow without
    bound, and the GLOBAL_VARIABLES collection used for checkpoint loading
    would also pick up every earlier request's variables.

    :param flags: config object passed to ``build_server_graph``.
    :param input_image: packed batch produced by ``_prepare_input``.
    :returns: the evaluated uint8 output batch. Its channel axis is reversed
        by ``tf.reverse`` relative to the network output.
    """
    with tf.Graph().as_default():
        model = InpaintCAModel()
        sess_config = tf.compat.v1.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        with tf.compat.v1.Session(config=sess_config) as sess:
            batch = tf.constant(input_image, dtype=tf.float32)
            output = model.build_server_graph(flags, batch)
            # Rescale (presumably from [-1, 1]) to [0, 255], reverse the
            # channel axis, then clamp into uint8.
            output = (output + 1.) * 127.5
            output = tf.reverse(output, [-1])
            output = tf.saturate_cast(output, tf.uint8)
            # Each variable is loaded from the checkpoint entry of the same name.
            assign_ops = [
                tf.compat.v1.assign(var, tf.train.load_variable(checkpoint_dir, var.name))
                for var in tf.compat.v1.get_collection(
                    tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)
            ]
            sess.run(assign_ops)
            print('Model loaded.')
            return sess.run(output)


@app.route('/ai_repaint', methods=['POST'])
def ai_repaint():
    """Inpaint an uploaded image using an uploaded mask.

    Expects a multipart POST with file fields ``image`` and ``mask``. Both
    are saved under ``output_dir`` and decoded with OpenCV. They are then
    run through ``InpaintCAModel`` with weights read from ``checkpoint_dir``.
    The result is written to ``output_dir + 'temp_result.png'``.

    Returns a JSON string. On success it is
    ``{"ret": 1, "msg": "success", "result": <output path>}``; on any
    failure it is ``{"ret": 0, "msg": <error text>}``.

    NOTE(review): the temp/result filenames are fixed and shared between
    requests, so concurrent requests would overwrite each other's files.
    Serve requests one at a time, or switch to unique names.
    """
    import os  # local import: os is not among the server's module-level imports

    FLAGS = ng.Config('inpaint.yml')
    result = {}
    image_input_path = output_dir + 'tmp_image.png'
    mask_path = output_dir + 'tmp_mask.png'
    image_out_path = output_dir + 'temp_result.png'
    try:
        if output_dir:
            # Saving an upload fails if the target directory does not exist.
            os.makedirs(output_dir, exist_ok=True)
        request.files['image'].save(image_input_path)
        request.files['mask'].save(mask_path)
        input_image = _prepare_input(cv2.imread(image_input_path), cv2.imread(mask_path))
        temp_result = _run_inpaint(FLAGS, input_image)
        # The graph reversed the channel axis; reverse it again before writing.
        # cv2.imwrite signals failure by returning False, not by raising.
        if not cv2.imwrite(image_out_path, temp_result[0][:, :, ::-1]):
            raise IOError('failed to write {}'.format(image_out_path))
        result['ret'] = 1
        result['msg'] = 'success'
        result['result'] = image_out_path
    except Exception as e:
        print('{} error {}'.format(sys._getframe().f_code.co_name, traceback.format_exc()))
        result['ret'] = 0
        # e.args can be empty (e.g. a bare AssertionError). Indexing it blindly
        # would raise a second exception inside this handler.
        result['msg'] = str(e.args[0]) if e.args else (str(e) or type(e).__name__)
    # Return after the try statement rather than from ``finally``: a return
    # in ``finally`` silently discards any exception that is still propagating.
    return json.dumps(result, ensure_ascii=False, default=lambda o: o.__dict__)
if __name__ == '__main__':
    # Serve one request at a time. The ai_repaint route reuses fixed temp and
    # result filenames under output_dir. The Flask development server is
    # threaded by default, so concurrent requests could otherwise overwrite
    # each other's inputs and results.
    app.run(host='0.0.0.0', port=5003, debug=False, threaded=False)
客户端：
import os
import requests
# Base URL of the inpainting service; ai_repaint appends the route path to it.
http_url = 'http://{host}:{port}'.format(host='127.0.0.1', port=5003)
def ai_repaint(image_input_path, mask_input_path, *, timeout=None):
    """Upload an image and its mask to the inpainting server.

    :param image_input_path: local path of the image to repair.
    :param mask_input_path: local path of the mask image.
    :param timeout: optional ``requests`` timeout in seconds. ``None`` (the
        default) waits indefinitely.
    :returns: the ``result`` field of the server's JSON reply (a path on the
        server side). Returns None if either input is not an existing regular
        file, or if the reply has no ``result`` field.
    :raises requests.RequestException: on connection/transport errors.
    :raises ValueError: if the reply body is not valid JSON.
    """
    # Return None before any network traffic if either input file is missing.
    # isfile (not exists) also rejects directories, which open() cannot read.
    if not os.path.isfile(image_input_path):
        return None
    if not os.path.isfile(mask_input_path):
        return None
    # ``with`` guarantees both uploads are closed, even if the request raises.
    with open(image_input_path, 'rb') as image_fp, open(mask_input_path, 'rb') as mask_fp:
        files = {
            'image': (os.path.basename(image_input_path), image_fp),
            'mask': (os.path.basename(mask_input_path), mask_fp),
        }
        response = requests.post(http_url + '/ai_repaint', files=files, timeout=timeout)
    result = response.json()
    return result.get('result')
if __name__ == '__main__':
    # Sample input image and its mask; prints the server-side result path (or None).
    sample_image = r'examples/places2/case1_input.png'
    sample_mask = r'examples/places2/case1_mask.png'
    print(ai_repaint(sample_image, sample_mask))
其中所需的两张输入图片（原图与掩码）：
输出结果一张：