一、引言
在深度学习领域,PyTorch 凭借其动态计算图、简洁易用的 API 以及强大的社区支持,已成为众多开发者的首选框架。从学术研究到工业应用,PyTorch 助力无数深度学习模型的开发与落地,如在图像识别中准确识别各类图像,在自然语言处理里实现智能对话和文本生成 。
然而,深度学习模型的开发与展示并非易事。在开发过程中,可视化工具能帮助开发者更好地理解模型行为、调试模型;在模型完成后,又需要一种有效的方式向他人展示模型的效果。Gradio 和 Streamlit 便是两款在 PyTorch 生态中极具价值的可视化工具,它们能将复杂的深度学习模型转化为直观、交互性强的 Web 应用,让模型的展示和使用变得更加便捷,极大地提升了开发效率与用户体验。接下来,就让我们深入了解这两款工具。
二、Gradio 技术详解
2.1 Gradio 简介与安装
Gradio 是一个基于 Python 的开源库 ,它能让机器学习模型的演示和分享变得前所未有的简单。开发者只需几行代码,就能将任何 Python 函数,无论是简单的文本处理函数,还是复杂的深度学习模型,轻松 “包装” 成一个美观、直观、可交互的 Web 应用。
Gradio 具有诸多显著特点。它极为轻量级且高效,特别适合快速原型开发、模型演示以及 AI 教育等场景,能帮助开发者节省大量时间和精力。同时,Gradio 支持多种输入输出类型,无论是文本、图像、音频、视频,还是 3D 模型、时间序列数据等,都能找到对应的组件,满足各种 AI 模型的需求。它还提供了实时交互体验,用户操作能立即触发模型运行并返回结果,方便快速测试模型性能、收集用户反馈并进行迭代优化 。
要安装 Gradio 非常简单,使用 pip 命令即可:
pip install gradio
2.2 Gradio 核心组件与接口
Gradio 的核心接口类是gr.Interface,它通过指定输入类型和输出类型,帮助用户快速创建任何 Python 函数的交互式演示。gr.Interface类有三个核心参数:
- fn:这是将用户界面(UI)包裹起来的函数,也就是实际执行任务的函数,比如图像分类模型的预测函数。
- inputs:用于输入的 Gradio 组件,其组件数量应与函数中的参数数匹配。可以是文本框("text")、图像上传组件("image")、音频上传组件("audio")等 。
- outputs:用于输出的 Gradio 组件,组件数应与函数的返回值数匹配,如文本输出框("text")、标签输出("label")、图像输出框("image")等。
以一个简单的文本问候函数为例:
import gradio as gr
def greet(name):
return "Hello " + name + "!"
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
demo.launch()
在上述代码中,greet函数接收一个字符串参数name,返回问候语。gr.Interface将greet函数与输入组件"text"(文本输入框)和输出组件"text"(文本输出框)关联起来,最后通过demo.launch()启动应用,用户在浏览器中就能看到一个简单的交互界面,输入名字后可得到问候语。
Gradio 还提供了众多常见的输入输出组件。输入组件如gr.Textbox(比"text"有更大空间,可添加提示字符串),用法为inputs=gr.Textbox(lines=2, placeholder="Name Here...");gr.Slider用于创建滑动条,可选择一个范围内的数字,如gr.Slider(minimum=0, maximum=100, step=1, label="选择数值") 。输出组件如gr.Label常用于显示分类标签和置信度,gr.Image用于显示图像输出等。
2.3 Gradio 高级功能
Gradio 支持多输入输出功能,当输入和输出比较复杂,有多个参数时,可以通过列表的方式传递参数。例如,下面是一个实现加减乘除功能的计算器示例:
import gradio as gr
def calculator(num1, operation, num2):
if operation == "add":
return num1 + num2
elif operation == "subtract":
return num1 - num2
elif operation == "multiply":
return num1 * num2
elif operation == "divide":
if num2 == 0:
raise gr.Error("Cannot divide by zero!")
return num1 / num2
demo = gr.Interface(calculator,
["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
"number",
examples=[[5, "add", 3],
[4, "divide", 2],
[-4, "multiply", 2.5],
[0, "subtract", 1.2]],
title="Toy Calculator",
description="Here's a sample toy calculator. Enjoy!")
if __name__ == "__main__":
demo.launch()
在这个例子中,calculator函数接收三个参数:两个数字num1、num2和一个操作符operation,返回计算结果。gr.Interface的inputs参数通过列表指定了三个输入组件,分别是数字输入框、单选按钮(用于选择操作符)和数字输入框;outputs参数指定为数字输出。examples参数提供了一些示例输入,方便用户快速体验。
通过设置组件属性,可以定制界面的外观和行为。比如设置界面的标题title、描述description、主题theme等。还可以通过live=True参数实现实时交互,即用户输入变化时,界面立即显示最新结果 。
在界面中展示示例能帮助用户更好地理解如何使用应用。通过examples参数可以传入一个列表,列表中的每个元素是一个输入示例,如上述计算器示例中的examples参数。
以图像分类模型为例,展示 Gradio 在模型推理中的应用:
import torch
import gradio as gr
from PIL import Image
from torchvision import transforms
import requests
# 加载模型,机器要能联网,需要下载训练好的公开模型
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet152', pretrained=True).eval()
# 加载1000类标签
labels = []
# labels.txt 可以通过`https://git.io/JJkYN`下载
with open('labels.txt') as f:
for ln in f:
label = ln.rstrip('\n')
labels.append(label)
def predict(inp):
inp = Image.fromarray(inp.astype('uint8'), 'RGB')
inp = transforms.ToTensor()(inp).unsqueeze(0)
with torch.no_grad():
prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
return {labels[i]: float(prediction[i]) for i in range(1000)}
inputs = gr.inputs.Image()
outputs = gr.outputs.Label(num_top_classes=3)
# share = True 表示可以生成一个url链接,公众通过该url就能体验该功能,有效期72h
gr.Interface(fn=predict, inputs=inputs, outputs=outputs).launch(share=True)
在这段代码中,定义了predict函数用于图像分类预测。加载预训练的 ResNet152 模型后,对输入图像进行预处理,然后通过模型进行预测,返回预测结果(各类别的概率)。gr.Interface将predict函数与图像输入组件和标签输出组件关联,设置outputs的num_top_classes=3表示只显示概率最高的前三个类别。最后通过launch(share=True)生成一个可分享的链接,方便他人体验图像分类模型。
三、Streamlit 技术详解
3.1 Streamlit 简介与安装
Streamlit 是一个开源的 Python 库,它让数据科学家和机器学习工程师能够快速创建漂亮且交互式的 Web 应用程序 ,而无需具备深厚的前端开发知识。Streamlit 的设计理念是让开发者专注于数据逻辑和业务实现,通过简单的 Python 代码就能构建出功能强大的应用,将数据脚本轻松转化为可分享的网页应用 。
使用 Streamlit 有诸多好处。它的语法简洁直观,开发者可以使用熟悉的 Python 语法快速搭建应用,极大地提高了开发效率,比如在开发数据可视化应用时,能快速将数据以图表形式展示出来。Streamlit 支持实时更新,当代码发生变化时,应用程序会自动重新加载,无需手动刷新页面,这使得开发过程更加流畅 。它还提供了丰富的组件,如按钮、滑块、文本输入框等,方便与用户进行交互,增强用户体验 。
安装 Streamlit 很简单,在已经安装 Python 的环境下,使用 pip 命令进行安装:
pip install streamlit
3.2 Streamlit 基础用法
使用 Streamlit 创建应用程序非常简单。下面是一个简单的示例,展示如何创建一个包含文本和按钮的基本应用:
import streamlit as st
# 设置页面标题
st.title('我的第一个Streamlit应用')
# 添加文本
st.write('欢迎使用Streamlit,这是一个简单的示例。')
# 添加按钮
if st.button('点击我'):
st.write('你点击了按钮!')
在上述代码中,首先导入 Streamlit 库并别名为st。使用st.title设置页面的标题,st.write用于显示文本内容。st.button创建了一个按钮,当按钮被点击时,通过条件判断if st.button('点击我'),执行相应的代码块,显示 “你点击了按钮!” 。
Streamlit 还可以方便地显示数据。例如,展示一个 Pandas 数据框:
import streamlit as st
import pandas as pd
import numpy as np
# 创建数据框
data = pd.DataFrame(
np.random.randn(10, 5),
columns=['col1', 'col2', 'col3', 'col4', 'col5']
)
# 显示数据框
st.dataframe(data)
这段代码创建了一个包含 10 行 5 列随机数据的 Pandas 数据框,然后使用st.dataframe将数据框以交互式表格的形式展示在应用中,用户可以对表格进行排序、筛选等操作 。
在 Streamlit 中绘制图表也十分便捷,以绘制折线图为例:
import streamlit as st
import pandas as pd
import numpy as np
# 创建数据
data = pd.DataFrame(
np.random.randn(20, 2),
columns=['col1', 'col2']
)
# 绘制折线图
st.line_chart(data)
代码中生成了一个包含 20 行 2 列随机数据的数据框,通过st.line_chart直接将数据框绘制成折线图展示出来,能直观地呈现数据的变化趋势 。
3.3 Streamlit 进阶功能
Streamlit 的缓存机制是其重要的进阶功能之一。Streamlit 应用在运行时,每次用户交互都会触发整个脚本的重新执行,这可能导致一些耗时操作(如数据加载、复杂计算和模型训练等)被重复执行,严重影响应用响应速度。为了解决这些问题,Streamlit 提供了缓存机制,通过st.cache_data和st.cache_resource装饰器来实现。
st.cache_data用于缓存数据,适用于缓存函数的输出结果,特别是那些返回可序列化数据对象的函数(如 Pandas DataFrame、NumPy 数组、字符串、整数等) 。例如:
import streamlit as st
import requests
import pandas as pd
# 使用 st.cache_data 缓存数据加载
@st.cache_data(ttl=3600) # 缓存1小时
def fetch_data(api_url):
response = requests.get(api_url)
data = response.json()
df = pd.DataFrame(data)
return df
# 用户界面部分
st.title("使用 st.cache_data 缓存数据加载")
api_url = "https://jsonplaceholder.typicode.com/posts"
df = fetch_data(api_url)
st.write(df)
在这个例子中,fetch_data函数被@st.cache_data装饰器修饰,ttl=3600表示缓存的生存时间为 1 小时。第一次调用fetch_data函数时,数据会被加载并缓存,后续调用时如果在 1 小时内,直接从缓存中读取数据,避免重复请求 API,从而提高应用的运行效率 。
st.cache_resource用于缓存资源,适用于缓存那些需要初始化但不需要频繁重新计算的对象,如数据库连接、模型加载等 。比如:
import streamlit as st
import joblib
# 使用 st.cache_resource 缓存模型加载
@st.cache_resource
def load_model(model_path):
model = joblib.load(model_path)
return model
# 用户界面部分
st.title("使用 st.cache_resource 缓存模型加载")
model_path = "path/to/your/model.pkl"
model = load_model(model_path)
st.write("模型已加载,可以进行预测!")
此代码中,load_model函数被@st.cache_resource装饰器修饰,模型加载后会被缓存,后续调用时直接从缓存中读取模型,避免重复加载,节省时间和资源 。
Streamlit 提供了多种布局选项,以满足不同的页面布局需求。分栏布局可以将页面分为多列,使内容展示更加紧凑和有条理 。例如:
import streamlit as st
# 分栏布局
col1, col2 = st.columns(2)
col1.write('这是第一列')
col2.write('这是第二列')
上述代码使用st.columns(2)将页面分为两列,分别在col1和col2中写入不同的内容,实现分栏展示 。
选项卡布局则可以在同一页面中切换不同的内容区域,方便用户查看不同类型的信息 。示例如下:
import streamlit as st
# 选项卡布局
tab1, tab2 = st.tabs(['选项卡1', '选项卡2'])
with tab1:
st.write('内容在选项卡1中')
with tab2:
st.write('内容在选项卡2中')
这段代码创建了两个选项卡tab1和tab2,通过with语句分别在不同的选项卡中添加内容,用户可以通过点击选项卡来切换查看不同的内容 。
Streamlit 拥有丰富的用户交互组件,能增强用户与应用的互动性。比如滑动条组件st.slider,可以让用户在指定范围内选择一个数值 :
import streamlit as st
# 滑动条组件
value = st.slider('选择一个数值', 0, 100, 50)
st.write(f'你选择的数值是 {value}')
在这个例子中,st.slider创建了一个滑动条,用户可以在 0 到 100 的范围内选择一个数值,默认值为 50。选择的数值会赋值给value变量,并通过st.write显示出来 。
下拉选择框组件st.selectbox,用于让用户从预设的选项中选择一个 :
import streamlit as st
# 下拉选择框组件
option = st.selectbox(
'选择一个选项',
['选项A', '选项B', '选项C']
)
st.write(f'你选择了 {option}')
这里st.selectbox创建了一个下拉选择框,用户可以从['选项A', '选项B', '选项C']这三个选项中选择一个,选择的结果会通过st.write展示 。
以构建一个简单的数据可视化应用为例,综合运用上述进阶功能:
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# 使用缓存加载数据
@st.cache_data
def load_data():
data = pd.DataFrame(
np.random.randn(100, 3),
columns=['col1', 'col2', 'col3']
)
return data
# 加载数据
df = load_data()
# 设置页面标题
st.title('数据可视化应用')
# 分栏布局
col1, col2 = st.columns(2)
# 第一列:显示数据框
with col1:
st.write('数据预览')
st.dataframe(df.head())
# 第二列:用户交互组件和图表
with col2:
# 滑动条选择要显示的行数
num_rows = st.slider('选择显示行数', 5, 50, 10)
st.write(f'显示前 {num_rows} 行数据')
st.dataframe(df.head(num_rows))
# 下拉选择框选择要绘制的列
selected_col = st.selectbox(
'选择要绘制的列',
df.columns
)
# 绘制柱状图
fig, ax = plt.subplots()
ax.bar(df.index, df[selected_col])
st.pyplot(fig)
在这个示例中,首先使用@st.cache_data装饰器缓存load_data函数的结果,提高数据加载效率 。通过st.columns进行分栏布局,在第一列中显示数据框的前几行数据,在第二列中添加了滑动条和下拉选择框组件,用户可以通过滑动条选择显示的数据行数,通过下拉选择框选择要绘制图表的列 。最后根据用户选择绘制相应的柱状图,实现了一个简单但功能丰富的数据可视化应用,展示了 Streamlit 在数据处理和可视化方面的强大能力 。
四、Gradio 在 PyTorch 中的应用案例
4.1 图像分类案例
以猫狗分类这一经典的图像分类任务为例,使用 Gradio 构建 PyTorch 图像分类模型的可视化界面,能够直观地展示模型的分类效果,便于用户与模型进行交互。以下是实现该案例的完整步骤和代码:
数据准备
首先,需要准备猫狗分类的数据集。可以从公开的数据集平台(如 Kaggle、ModelScope 等)获取,也可以自行收集整理。假设已经下载好数据集,并按照训练集和测试集的格式进行了划分,且每个类别有对应的文件夹存放图像。
定义数据加载和预处理的代码如下:
import os
import csv
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
class DatasetLoader(Dataset):
def __init__(self, csv_path):
self.csv_file = csv_path
with open(self.csv_file, 'r') as file:
self.data = list(csv.reader(file))
self.current_dir = os.path.dirname(os.path.abspath(__file__))
def preprocess_image(self, image_path):
full_path = os.path.join(self.current_dir, 'datasets', image_path)
image = Image.open(full_path)
image_transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
return image_transform(image)
def __getitem__(self, index):
image_path, label = self.data[index]
image = self.preprocess_image(image_path)
return image, int(label)
def __len__(self):
return len(self.data)
batch_size = 8
TrainDataset = DatasetLoader("datasets/train.csv")
ValDataset = DatasetLoader("datasets/val.csv")
TrainDataLoader = DataLoader(TrainDataset, batch_size=batch_size, shuffle=True)
ValDataLoader = DataLoader(ValDataset, batch_size=batch_size, shuffle=False)
在这段代码中,DatasetLoader类继承自torch.utils.data.Dataset,用于加载数据集。__init__方法读取 CSV 文件中的数据路径和标签,并记录当前目录。preprocess_image方法对图像进行预处理,包括调整大小、转换为张量以及归一化。__getitem__方法根据索引返回预处理后的图像和标签,__len__方法返回数据集的大小。最后,使用DataLoader对训练集和验证集进行封装,设置了batch_size和shuffle参数 。
模型训练
选择经典的卷积神经网络模型,如 ResNet50,进行猫狗分类模型的训练。代码如下:
import torch
import torch.nn as nn
from torchvision.models import resnet50
# 加载预训练的ResNet50模型
model = resnet50(pretrained=True)
num_ftrs = model.fc.in_features
# 修改最后一层全连接层,使其适应猫狗分类任务(2个类别)
model.fc = nn.Linear(num_ftrs, 2)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 训练模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
for epoch in range(10):
running_loss = 0.0
running_corrects = 0
for inputs, labels in TrainDataLoader:
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
epoch_loss = running_loss / len(TrainDataLoader.dataset)
epoch_acc = running_corrects.double() / len(TrainDataLoader.dataset)
print(f'Epoch {epoch + 1}, Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}')
上述代码中,首先加载预训练的 ResNet50 模型,然后修改其最后一层全连接层,将输出类别数改为 2,以适应猫狗分类任务 。接着定义了交叉熵损失函数和 Adam 优化器。在训练过程中,将数据加载到 GPU 上(如果可用),进行前向传播、计算损失、反向传播和参数更新 。每个 epoch 结束后,计算并打印当前 epoch 的损失和准确率 。
Gradio 界面搭建
训练好模型后,使用 Gradio 搭建可视化界面,让用户可以上传图像并获得分类结果。代码如下:
import gradio as gr
from PIL import Image
# 定义预测函数
def predict(inp):
inp = Image.fromarray(inp.astype('uint8'), 'RGB')
inp = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])(inp).unsqueeze(0)
model.eval()
with torch.no_grad():
outputs = model(inp.to(device))
_, preds = torch.max(outputs, 1)
return "猫" if preds.item() == 0 else "狗"
# 创建Gradio界面
inputs = gr.inputs.Image()
outputs = gr.outputs.Textbox()
gr.Interface(fn=predict, inputs=inputs, outputs=outputs, title="猫狗分类",
description="上传一张图片,判断是猫还是狗").launch(share=True)
在这段代码中,predict函数接收用户上传的图像,对其进行预处理后,通过训练好的模型进行预测,并返回分类结果 。gr.Interface创建了一个 Gradio 界面,将predict函数与图像输入组件和文本输出组件关联起来,设置了界面的标题和描述,并通过launch(share=True)启动应用,生成一个可分享的链接,方便他人使用 。
4.2 夜景增强案例
利用 Gradio 部署基于 PyTorch 的夜景增强模型,能够让用户直观地看到夜景图像增强后的效果。以下是相关的实现步骤和代码:
模型结构
以 “Learning to See in the Dark” 论文中的模型为例,其结构类似 U - Net,采用编码 - 解码(Encoder - Decoder)架构,通过卷积层(Conv)、下采样(Pooling)、反卷积(UpConv)和像素重排(Pixel Shuffle)来实现高质量的图像增强 。网络结构代码如下:
import torch
import torch.nn as nn
class SeeInDark(nn.Module):
def __init__(self, num_classes=10):
super(SeeInDark, self).__init__()
self.conv1_1 = nn.Conv2d(4, 32, kernel_size=3, stride=1, padding=1)
self.conv1_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
self.pool1 = nn.MaxPool2d(kernel_size=2)
self.conv2_1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.conv2_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
self.pool2 = nn.MaxPool2d(kernel_size=2)
self.conv3_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.conv3_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
self.pool3 = nn.MaxPool2d(kernel_size=2)
self.conv4_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
self.conv4_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
self.pool4 = nn.MaxPool2d(kernel_size=2)
self.conv5_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.upv6 = nn.ConvTranspose2d(512, 256, 2, stride=2)
self.conv6_1 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
self.conv6_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
self.upv7 = nn.ConvTranspose2d(256, 128, 2, stride=2)
self.conv7_1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
self.conv7_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
self.upv8 = nn.ConvTranspose2d(128, 64, 2, stride=2)
self.conv8_1 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
self.conv8_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
self.upv9 = nn.ConvTranspose2d(64, 32, 2, stride=2)
self.conv9_1 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
self.conv9_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
self.conv10_1 = nn.Conv2d(32, 12, kernel_size=1, stride=1)
def forward(self, x):
conv1 = self.lrelu(self.conv1_1(x))
conv1 = self.lrelu(self.conv1_2(conv1))
pool1 = self.pool1(conv1)
conv2 = self.lrelu(self.conv2_1(pool1))
conv2 = self.lrelu(self.conv2_2(conv2))
pool2 = self.pool1(conv2)
conv3 = self.lrelu(self.conv3_1(pool2))
conv3 = self.lrelu(self.conv3_2(conv3))
pool3 = self.pool1(conv3)
conv4 = self.lrelu(self.conv4_1(pool3))
conv4 = self.lrelu(self.conv4_2(conv4))
pool4 = self.pool1(conv4)
conv5 = self.lrelu(self.conv5_1(pool4))
conv5 = self.lrelu(self.conv5_2(conv5))
up6 = self.upv6(conv5)
up6 = torch.cat([up6, conv4], 1)
conv6 = self.lrelu(self.conv6_1(up6))
conv6 = self.lrelu(self.conv6_2(conv6))
up7 = self.upv7(conv6)
up7 = torch.cat([up7, conv3], 1)
conv7 = self.lrelu(self.conv7_1(up7))
conv7 = self.lrelu(self.conv7_2(conv7))
up8 = self.upv8(conv7)
up8 = torch.cat([up8, conv2], 1)
conv8 = self.lrelu(self.conv8_1(up8))
conv8 = self.lrelu(self.conv8_2(conv8))
up9 = self.upv9(conv8)
up9 = torch.cat([up9, conv1], 1)
conv9 = self.lrelu(self.conv9_1(up9))
conv9 = self.lrelu(self.conv9_2(conv9))
conv10 = self.conv10_1(conv9)
out = nn.functional.pixel_shuffle(conv10, 2)
return out
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0.0, 0.02)
if m.bias is not None:
m.bias.data.normal_(0.0, 0.02)
if isinstance(m, nn.ConvTranspose2d):
m.weight.data.normal_(0.0, 0.02)
def lrelu(self, x):
outt = torch.max(0.2 * x, x)
return outt
在这个模型中,编码部分通过多个 3×3 卷积、Leaky ReLU 激活函数和最大池化来提取图像的层次化特征,逐步降低分辨率并增加通道数 。解码部分采用转置卷积、跳跃连接和卷积来恢复高分辨率特征,通过跳跃连接结合低层细节信息,避免信息丢失 。最后在输出层使用像素重排,将通道信息转换为空间信息,实现超分辨率增强,提升最终图像的清晰度 。
推理过程
加载训练好的模型,并进行推理的代码如下:
import torch
from PIL import Image
import torchvision.transforms as transforms
# 加载模型
model = SeeInDark()
# 假设模型权重保存在'best_model.pth'文件中
model.load_state_dict(torch.load('best_model.pth'))
model.eval()
# 定义图像预处理函数
def preprocess_image(image):
transform = transforms.Compose([
transforms.ToTensor()
])
return transform(image).unsqueeze(0)
# 定义推理函数
def enhance_image(inp):
inp = Image.fromarray(inp.astype('uint8'), 'RGB')
input_tensor = preprocess_image(inp)
with torch.no_grad():
output = model(input_tensor)
output = output.squeeze(0).permute(1, 2, 0).numpy()
output = (output * 255).clip(0, 255).astype('uint8')
return output
这段代码首先加载SeeInDark模型,并从文件中加载预训练的模型权重 。preprocess_image函数将输入图像转换为张量并添加一个维度,以适应模型的输入要求 。enhance_image函数接收用户上传的图像,进行预处理后,通过模型进行推理,最后对输出结果进行处理,将其转换为合适的图像格式并返回 。
Gradio 界面实现
使用 Gradio 创建夜景增强模型的可视化界面,代码如下:
import gradio as gr
# 创建Gradio界面
inputs = gr.inputs.Image()
outputs = gr.outputs.Image()
gr.Interface(fn=enhance_image, inputs=inputs, outputs=outputs, title="夜景增强",
description="上传一张夜景图片,获取增强后的图像").launch(share=True)
在这段代码中,gr.Interface将enhance_image函数与图像输入组件和图像输出组件关联起来,设置了界面的标题和描述,并通过launch(share=True)启动应用,生成可分享的链接 。用户在界面中上传夜景图像后,能够立即看到增强后的图像效果,直观地展示了夜景增强模型的能力 。
五、Streamlit 在 PyTorch 中的应用案例
5.1 手写体数字识别案例
利用 Streamlit 实现 PyTorch 手写体数字识别模型的可视化应用,能让用户更直观地体验模型的识别能力 。以下是详细的实现步骤和代码:
模型训练与保存
使用经典的卷积神经网络(CNN)结构进行手写体数字识别模型的训练 。代码如下:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定义CNN模型
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.fc1 = nn.Linear(32 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = nn.functional.relu(self.conv1(x))
x = nn.functional.max_pool2d(x, 2)
x = nn.functional.relu(self.conv2(x))
x = nn.functional.max_pool2d(x, 2)
x = x.view(-1, 32 * 7 * 7)
x = nn.functional.relu(self.fc1(x))
x = self.fc2(x)
return x
# 数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# 初始化模型、损失函数和优化器
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练模型
for epoch in range(5):
running_loss = 0.0
for i, (images, labels) in enumerate(train_loader):
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if (i + 1) % 100 == 0:
print(f'Epoch {epoch + 1}, Step {i + 1}, Loss: {running_loss / 100:.4f}')
running_loss = 0.0
# 保存模型
torch.save(model.state_dict(), 'mnist_cnn.pth')
在这段代码中,CNN类定义了卷积神经网络结构,包含两个卷积层和两个全连接层 。通过transforms.Compose对 MNIST 数据集中的图像进行预处理,包括转换为张量和归一化 。使用DataLoader加载训练数据集,设置batch_size为 64 并打乱数据 。在训练过程中,使用交叉熵损失函数和 Adam 优化器,每 100 个步骤打印一次损失值 。训练完成后,将模型的状态字典保存到mnist_cnn.pth文件中 。
Streamlit 界面实现
创建 Streamlit 界面,实现手写体数字识别的可视化应用 。代码如下:
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
# 定义CNN模型(与训练时相同)
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.fc1 = nn.Linear(32 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = nn.functional.relu(self.conv1(x))
x = nn.functional.max_pool2d(x, 2)
x = nn.functional.relu(self.conv2(x))
x = nn.functional.max_pool2d(x, 2)
x = x.view(-1, 32 * 7 * 7)
x = nn.functional.relu(self.fc1(x))
x = self.fc2(x)
return x
# 加载模型
model = CNN()
model.load_state_dict(torch.load('mnist_cnn.pth'))
model.eval()
# 数据预处理
transform = transforms.Compose([
transforms.Resize((28, 28)),
transforms.Grayscale(),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# 预测函数
def predict(image):
image = transform(image).unsqueeze(0)
with torch.no_grad():
outputs = model(image)
_, predicted = torch.max(outputs.data, 1)
return predicted.item()
# Streamlit界面
st.title('手写体数字识别')
uploaded_file = st.file_uploader("上传手写体数字图片", type=['jpg', 'jpeg', 'png'])
if uploaded_file is not None:
image = Image.open(uploaded_file)
st.image(image, caption='上传的图片', use_column_width=True)
prediction = predict(image)
st.write(f'预测结果: {prediction}')
在这个 Streamlit 应用中,首先定义了与训练时相同的CNN模型类,并加载保存的模型状态字典 。transform对上传的图像进行预处理,包括调整大小、转换为灰度图、转换为张量和归一化 。predict函数接收预处理后的图像,通过模型进行预测并返回预测结果 。Streamlit 界面使用st.title设置标题,st.file_uploader创建文件上传组件,用户上传图片后,通过st.image展示图片,并调用predict函数进行预测,最后使用st.write显示预测结果 。
5.2 手绘电路转换案例
利用 Streamlit 和 PyTorch 实现手绘电路图像转换为数字原理图,能够为电子工程师和电路设计爱好者提供便捷的工具 。以下是该项目的详细介绍:
项目原理
该项目主要基于计算机视觉和深度学习技术。首先,使用边缘检测算法(如 Canny 边缘检测)提取手绘电路图像中的边缘信息,突出电路线条和元件轮廓 。然后,利用卷积神经网络(CNN)对提取的边缘特征进行学习和分类,识别出不同的电路元件,如电阻、电容、电感、二极管等 。最后,根据识别结果和一定的布局算法,将手绘电路转换为标准的数字原理图 。
实现步骤
- 数据收集与标注:收集大量的手绘电路图像,并对其中的电路元件进行标注,建立训练数据集 。标注信息包括元件的类别、位置和连接关系等 。
- 模型训练:选择合适的 CNN 模型,如 ResNet、VGG 等,对标注好的数据集进行训练 。在训练过程中,通过反向传播算法不断调整模型的参数,使其能够准确地识别手绘电路中的元件 。
- 边缘检测:在 Streamlit 应用中,用户上传手绘电路图像后,首先使用 OpenCV 库中的 Canny 边缘检测算法对图像进行处理,提取边缘信息 。
- 元件识别:将边缘检测后的图像输入到训练好的 PyTorch 模型中,进行元件识别 。模型输出每个元件的类别和位置信息 。
- 原理图生成:根据元件识别结果,按照一定的规则和布局算法,将元件连接起来,生成数字原理图 。可以使用绘图库(如 Matplotlib、Plotly 等)将原理图绘制出来 。
示例代码与效果展示
以下是一个简化的示例代码,展示如何使用 Streamlit 和 PyTorch 进行手绘电路转换的部分功能:
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import cv2
import numpy as np
from PIL import Image
# 定义简单的CNN模型
class CircuitCNN(nn.Module):
def __init__(self):
super(CircuitCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.fc1 = nn.Linear(32 * 64 * 64, 128)
self.fc2 = nn.Linear(128, 5) # 假设5种元件类别
def forward(self, x):
x = nn.functional.relu(self.conv1(x))
x = nn.functional.max_pool2d(x, 2)
x = nn.functional.relu(self.conv2(x))
x = nn.functional.max_pool2d(x, 2)
x = x.view(-1, 32 * 64 * 64)
x = nn.functional.relu(self.fc1(x))
x = self.fc2(x)
return x
# 数据预处理
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.Grayscale(),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 加载模型
model = CircuitCNN()
model.load_state_dict(torch.load('circuit_model.pth'))
model.eval()
# 边缘检测函数
def edge_detection(image):
gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
edges = cv2.Canny(gray, 50, 150)
return Image.fromarray(edges)
# 预测函数
def predict(image):
image = transform(image).unsqueeze(0)
with torch.no_grad():
outputs = model(image)
_, predicted = torch.max(outputs.data, 1)
return predicted.item()
# Streamlit界面
st.title('手绘电路转换为数字原理图')
uploaded_file = st.file_uploader("上传手绘电路图片", type=['jpg', 'jpeg', 'png'])
if uploaded_file is not None:
image = Image.open(uploaded_file)
st.image(image, caption='上传的手绘电路图片', use_column_width=True)
# 边缘检测
edges = edge_detection(image)
st.image(edges, caption='边缘检测后的图像', use_column_width=True)
# 元件识别
prediction = predict(edges)
st.write(f'识别的元件类别: {prediction}')
在这个示例中,定义了一个简单的CircuitCNN模型用于元件识别 。transform对上传的图像进行预处理 。edge_detection函数使用 Canny 算法进行边缘检测 。predict函数将边缘检测后的图像输入模型进行预测 。Streamlit 界面设置了标题和文件上传组件,用户上传图片后,依次展示原始图片、边缘检测后的图片以及元件识别结果 。实际应用中,还需要进一步完善原理图生成部分的代码,以实现完整的手绘电路转换功能 。通过这样的项目,可以让用户方便地将手绘电路转换为数字原理图,提高电路设计和分析的效率 。
六、Gradio 与 Streamlit 对比分析
在了解了 Gradio 和 Streamlit 的技术细节以及它们在 PyTorch 中的应用案例后,对这两款工具进行对比分析,有助于我们在实际项目中根据具体需求做出更合适的选择。下面从功能特点、易用性、性能、适用场景等多个方面进行详细对比。
功能特点
Gradio 专注于机器学习模型的快速部署与展示,提供简洁的 API,能方便地将模型封装为交互式 Web 应用,其核心在于对各类深度学习框架的良好兼容性,支持多种输入输出组件,如文本框、图像、音频等,特别适合展示模型的输入输出关系 。
Streamlit 功能更为丰富和通用,不仅可用于模型展示,还广泛应用于数据可视化、数据分析等领域。它提供了丰富的组件库,如按钮、滑块、图表等,支持使用 Markdown 和 HTML 进行样式定制,能实现高级布局和主题配置,可构建复杂的交互式应用程序 。
易用性
Gradio 极为轻量级,几行代码就能创建界面,开发门槛低,上手容易,适合快速搭建小型应用或演示项目,对于机器学习初学者和快速验证场景非常友好 。
Streamlit 虽然也旨在简化 Web 应用开发,但相比 Gradio,其功能的丰富性导致开发门槛稍高,需要开发者熟悉 Python 和一些前端布局知识,不过一旦掌握,能实现更复杂的场景 。
性能
Gradio 比 Streamlit 更轻量,启动速度更快,在快速展示模型效果、进行简单交互时,能迅速响应 。
Streamlit 启动稍慢,但在运行时表现流畅,适合处理大规模数据和构建大型项目,其缓存机制能有效优化耗时操作,提升应用的整体性能 。
适用场景
Gradio 适合快速展示 AI 模型的输入输出,如在图像生成、文本处理、模型演示和验证等场景中,能快速将模型部署为可供他人体验的 Web 应用 。
Streamlit 更适合构建带有数据可视化和交互性的仪表盘,如数据分析报告、完整的数据科学和机器学习应用程序等,能满足复杂的业务逻辑和交互需求 。
组件支持
Gradio 提供的组件相对有限,主要集中在常见的输入输出组件,如文本框、按钮、文件上传、图片显示等,但其输入 / 输出组合功能强大,不过扩展性较低 。
Streamlit 支持丰富的组件,除了基本组件外,还包括表格、滑块、选项卡、基于 Matplotlib、Plotly 等的图表组件等,扩展性高,能满足多样化的界面设计需求 。
部署难易度
Gradio 提供直接的托管服务,支持快速共享 URL,可通过设置share=True生成临时分享链接,也支持嵌入到网站或通过 API 集成,部署简单便捷 。
Streamlit 需要自行部署到服务器或使用 Streamlit Cloud,部署过程相对复杂,但更适合长期运行的应用,能提供更稳定的服务 。
生态支持
Gradio 官方专注于 AI 相关功能,开箱即用,对于机器学习相关的功能支持较好,适合快速验证和展示 。
Streamlit 社区活跃,拥有许多扩展包和第三方集成,如st_aggrid、st_pages等,开发者可以方便地获取各种资源和解决方案,有助于构建更丰富、功能更强大的应用 。
通过以上对比可以看出,Gradio 和 Streamlit 各有优势。如果是进行简单的机器学习模型演示、快速验证想法或展示模型的输入输出,Gradio 是不错的选择;而如果需要构建复杂的数据可视化应用、数据分析仪表盘或完整的机器学习应用程序,Streamlit 则更能发挥其强大的功能和灵活性 。在实际项目中,应根据具体需求、项目规模、开发时间等因素综合考虑,选择最适合的工具 。
七、总结与展望
Gradio 和 Streamlit 作为两款强大的 Python 可视化工具,在 PyTorch 深度学习模型的可视化和应用部署方面发挥了重要作用 。Gradio 以其轻量级、快速部署和对机器学习模型的友好支持,成为快速展示模型效果、进行模型演示和验证的首选工具,如在图像分类、夜景增强等案例中,能让用户迅速体验模型的功能 。Streamlit 凭借其丰富的组件库、强大的数据可视化能力和灵活的布局选项,更适合构建复杂的数据可视化应用、数据分析仪表盘以及完整的机器学习应用程序,如手写体数字识别、手绘电路转换等案例展示了其在实际项目中的强大功能 。
随着深度学习和数据科学的不断发展,未来可视化工具将朝着更加智能化、个性化和集成化的方向发展 。智能化方面,可能会融入更多的人工智能和机器学习技术,实现自动化的数据分析和可视化推荐 。个性化上,能根据用户的使用习惯和需求,提供定制化的可视化界面和交互方式 。集成化则体现在与更多的工具和平台进行无缝集成,拓展应用场景 。