1、Flask框架简介
Flask是一个使用Python编写的轻量级Web应用框架,可扩展性很强,相较于Django框架,灵活度很高,开发成本底。它仅仅实现了Web应用的核心功能,Flask由两个主要依赖组成,提供路由、调试、Web服务器网关接口的Werkzeug 实现的和模板语言依赖的jinja2,其他的一切都可以由第三方库来完成。
2、Flask框架安装
在使用Flask之前需要安装一下,安装Flask非常简单只需要在在命令行输入
pip install flask即可
3、Flask实现 Hello World案例
# 导入 Flask 类
from flask import Flask
# 创建了这个类的实例。第一个参数是应用模块或者包的名称。
app = Flask(__name__)
# 使用 route() 装饰器来告诉 Flask 触发函数的 URL
@app.route("/")
def hello():
return "Hello World!"
if __name__ == "__main__":
# 使用 run() 函数来运行本地服务器和我们的应用
app.run()
4、Flask深度学习模型部署
本文通过使用轻量级的WEB框架Flask来实现Python在服务端的部署CIFAR-10的图像分类。效果如下:
CIFAR-10是一个小型图像分类数据集,数据格式类似于MNIST手写数字数据集,在CIFAR-10数据中图片共有10个类别,分别为airplane、automobile、bird、cat、deer、dog、frog、horse、ship、truck。
4.1 数据加载
对于CIFAR-10分类任务,PyTorch里的torchvision库提供了专门数据处理函数torchvision.datasets.CIFAR10,构建DataLoader代码如下:
import torchvision
from torchvision import transforms
import torch
from config import data_folder, batch_size
def create_dataset(data_folder, transform_train=None, transform_test=None):
if transform_train is None:
transform_train = transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(
(0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)
)
]
)
if transform_test is None:
transform_test = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
(0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)
)
]
)
trainset = torchvision.datasets.CIFAR10(
root=data_folder, train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
trainset, batch_size=batch_size, shuffle=True, num_workers=2
)
testset = torchvision.datasets.CIFAR10(
root=data_folder, train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
testset, batch_size=batch_size, shuffle=False, num_workers=2
)
return trainloader, testloader
4.2 构建模型resent18实现分类
from torch import nn
import torch.nn.functional as F
# 定义残差块ResBlock
class ResBlock(nn.Module):
def __init__(self, inchannel, outchannel, stride=1):
super(ResBlock, self).__init__()
# 这里定义了残差块内连续的2个卷积层
self.left = nn.Sequential(
nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
nn.BatchNorm2d(outchannel),
nn.ReLU(inplace=True),
nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(outchannel)
)
self.shortcut = nn.Sequential()
if stride != 1 or inchannel != outchannel:
# s