目录
(2)、初始化Flask app ,创建一个新的Flask应用程序实例
(4)、定义图片预处理函数,对客户端传入的文件进行处理,使得处理后的图片数据可以传入模型进行分类
简介
优点:
轻量级:Flask是一个轻量级的框架,代码量少,灵活性高,适合快速开发小型应用程序。 简单易学:Flask的设计理念简洁明了,入门相对容易,对于初学者来说非常友好。 可扩展性强:Flask提供了丰富的扩展库,开发者可以根据需求选择合适的扩展来扩展功能。 社区支持良好:Flask有一个庞大的社区,提供了丰富的资源和支持。
缺点:
功能相对较少:相比于一些大型框架如Django,Flask的功能相对较少,需要依赖扩展库来实现一些功能。 安全性考虑:由于Flask的轻量级特性,安全性方面的考虑需要开发者自行关注。 不适合大型应用:由于Flask的轻量级特性,它可能不适合开发大型复杂的应用程序。 综上所述,Django、Pyramid和Flask各有其优缺点,选择哪个框架取决于项目的具体需求、开发者的偏好和经验水平。
flask库的安装
flask可以通过命令提示符按照指令 pip install flask 来安装
实践
项目介绍:通过flask框架构建一个服务器,当客户端向服务器发送一张花的图片时服务器会返回客户端关于图片中花的信息。
1、构建服务端
(1)、导入需要的库
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models, datasets
(2)、初始化Flask app ,创建一个新的Flask应用程序实例
# 初始化Flask app
app = flask.Flask(__name__)# 创建一个新的Flask应用程序实例
# __name__参数通常被传递给FasK应用程序来定位应用程序的根路径,这样FlasK就可以知道在哪里找到模板、静态文件等。
# 总体来说app = flask.Flask(__name_)是FLaSK应用程序的起点。它初始化了一个新的FLaSK应用程序实例。为后续添加路由、配置
model = None
use_gpu = False
(3)、定义残差网络模型函数,用于对传入的图片进行分类
def load_model():
global model
model = models.resnet18()
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))
checkpoint = torch.load('best.pth')
model.load_state_dict(checkpoint['state_dict'])
model.eval()
if use_gpu:
model.cuda()
(4)、定义图片预处理函数,对客户端传入的文件进行处理,使得处理后的图片数据可以传入模型进行分类
def prepare_image(image, target_size):
if image.mode != "RGB":
image = image.convert("RGB")
# Resize the input image and preprocess it.(按照所使用的模型将输入图片的尺寸修改,并转为tensor)
image = transforms.Resize(target_size)(image)
image = transforms.ToTensor()(image)
image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
# Add batch_size axis 增加一个维度,用于按batch测试本次这里一次测试- 张
image = image[None]
if use_gpu:
image = image.cuda()
return torch.tensor(image)
(5)、定义预测函数
注意:其中@app.route("/predict", methods=["POST"])定义了一个 URL 路径为/predict
。这意味着当用户在浏览器中访问这个路径时,Flask 会将该请求路由到下面被装饰的函数进行处理。
@app.route("/predict", methods=["POST"])
def predict():
# 做一个标志,刚开始无图像传入时为false,传入图像时为true
data = {'success': False}
if flask.request.method == 'POST':# 如果收到POST请求
if flask.request.files.get("image"):# 判断是否为图像
image = flask.request.files["image"].read()# 将收到的图像进行读取,内容为二进制
image = Image.open(io.BytesIO(image))
image = prepare_image(image, target_size=(224, 224))
preds = F.softmax(model(image), dim=1)
results = torch.topk(preds.cpu().data, k=3, dim=1)
results = (results[0].cpu().numpy(), results[1].cpu().numpy())
data['predictions'] = list()
for prob, label in zip(results[0][0], results[1][0]):
r = {'label': str(label), "probability": float(prob)}
data['predictions'].append(r)
data["success"] = True
return flask.jsonify(data) # 将最后结果以json格式文件传出
(6)、设置地址端口,开启服务
if __name__ == '__main__':
load_model()
app.run(host='192.168.24.66', port='5012')
完整代码
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models, datasets
# 初始化Flask app
app = flask.Flask(__name__)# 创建一个新的Flask应用程序实例
# __name__参数通常被传递给FasK应用程序来定位应用程序的根路径,这样FlasK就可以知道在哪里找到模板、静态文件等。
# 总体来说app = flask.Flask(__name_)是FLaSK应用程序的起点。它初始化了一个新的FLaSK应用程序实例。为后续添加路由、配置
model = None
use_gpu = False
def load_model():
global model
model = models.resnet18()
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))
checkpoint = torch.load('best.pth')
model.load_state_dict(checkpoint['state_dict'])
model.eval()
if use_gpu:
model.cuda()
def prepare_image(image, target_size):
if image.mode != "RGB":
image = image.convert("RGB")
# Resize the input image and preprocess it.(按照所使用的模型将输入图片的尺寸修改,并转为tensor)
image = transforms.Resize(target_size)(image)
image = transforms.ToTensor()(image)
image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
# Add batch_size axis 增加一个维度,用于按batch测试本次这里一次测试- 张
image = image[None]
if use_gpu:
image = image.cuda()
return torch.tensor(image)
@app.route("/predict", methods=["POST"])
def predict():
# 做一个标志,刚开始无图像传入时为false,传入图像时为true
data = {'success': False}
if flask.request.method == 'POST':# 如果收到POST请求
if flask.request.files.get("image"):# 判断是否为图像
image = flask.request.files["image"].read()# 将收到的图像进行读取,内容为二进制
image = Image.open(io.BytesIO(image))
image = prepare_image(image, target_size=(224, 224))
preds = F.softmax(model(image), dim=1)
results = torch.topk(preds.cpu().data, k=3, dim=1)
results = (results[0].cpu().numpy(), results[1].cpu().numpy())
data['predictions'] = list()
for prob, label in zip(results[0][0], results[1][0]):
r = {'label': str(label), "probability": float(prob)}
data['predictions'].append(r)
data["success"] = True
return flask.jsonify(data) # 将最后结果以json格式文件传出
if __name__ == '__main__':
load_model()
app.run(host='192.168.24.66', port='5012')
2、 构建客户端
import requests
flask_url = 'http://192.168.24.66:5012/predict' # 此为服务端的地址
def predict_result(image_path):
image = open(image_path, 'rb').read()
payload = {'image': image}
r = requests.post(flask_url, files=payload).json()
if r['success']:
for (i, result) in enumerate(r['predictions']):
print('{}.预测类别为{}:的概率:{}'.format(i + 1, result['label'], result['probability']))
else:
print('Request failed')
if __name__ == "__main__":
predict_result('image_06970.jpg')
运行客户端之前需要先开启服务端,让服务端进入监听状态,运行客户端代码就会将图片按照地址发送给服务端,等待服务端处理过后会返回客户端信息。
结果展示: