PyTorch ResNet实现图像分类(从模型的训练到Android部署)

1.数据集

数据集地址:10 Monkey Species

采用kaggle上的猴子数据集,包含两个文件:训练集和验证集。每个文件夹包含10个标记为n0-n9的猴子。图像尺寸为400x300像素或更大,并且为JPEG格式(近1400张图像)。

在这里插入图片描述

图片样本

在这里插入图片描述

图片类别标签,训练集,验证集划分说明

在这里插入图片描述

2.代码

2.1 定义需要的库
import os
import sys
import json
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
2.2 定义训练验证函数
def train_and_val(epochs, model, train_loader, val_loader, criterion, optimizer):
    torch.cuda.empty_cache()
    train_loss = []
    val_loss = []
    train_acc = []
    val_acc = []
    best_acc = 0

    model.to(device)
    fit_time = time.time()
    for e in range(epochs):
        since = time.time()
        running_loss = 0
        training_acc = 0
        with tqdm(total=len(train_loader)) as pbar:
            for image, label in train_loader:
                # training phase

                #                 images, labels = data
                #             optimizer.zero_grad()
                #             logits = net(images.to(device))
                #             loss = loss_function(logits, labels.to(device))
                #             loss.backward()
                #             optimizer.step()

                model.train()
                optimizer.zero_grad()
                image = image.to(device)
                label = label.to(device)
                # forward
                output = model(image)
                loss = criterion(output, label)
                predict_t = torch.max(output, dim=1)[1]

                # backward
                loss.backward()
                optimizer.step()  # update weight

                running_loss += loss.item()
                training_acc += torch.eq(predict_t, label).sum().item()
                pbar.update(1)

        model.eval()
        val_losses = 0
        validation_acc = 0
        # validation loop
        with torch.no_grad():
            with tqdm(total=len(val_loader)) as pb:
                for image, label in val_loader:
                    image = image.to(device)
                    label = label.to(device)
                    output = model(image)

                    # loss
                    loss = criterion(output, label)
                    predict_v = torch.max(output, dim=1)[1]

                    val_losses += loss.item()
                    validation_acc += torch.eq(predict_v, label).sum().item()
                    pb.update(1)

            # calculatio mean for each batch
            train_loss.append(running_loss / len(train_dataset))
            val_loss.append(val_losses / len(val_dataset))

            train_acc.append(training_acc / len(train_dataset))
            val_acc.append(validation_acc / len(val_dataset))
            
            torch.save(model, "last.pth")
            if best_acc<(validation_acc / len(val_dataset)):
                torch.save(model, "best.pth")
            

            print("Epoch:{}/{}..".format(e + 1, epochs),
                  "Train Acc: {:.3f}..".format(training_acc / len(train_dataset)),
                  "Val Acc: {:.3f}..".format(validation_acc / len(val_dataset)),
                  "Train Loss: {:.3f}..".format(running_loss / len(train_dataset)),
                  "Val Loss: {:.3f}..".format(val_losses / len(val_dataset)),
                  "Time: {:.2f}s".format((time.time() - since)))
            

    history = {'train_loss': train_loss, 'val_loss': val_loss,'train_acc': train_acc, 'val_acc': val_acc}
    print('Total time: {:.2f} m'.format((time.time() - fit_time) / 60))
    
    return history
2.3定义ResNet网络
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=10,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x

def resnet34(num_classes=10, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=10, include_top=True):
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=10, include_top=True):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)


def resnext50_32x4d(num_classes=10, include_top=True):
    groups = 32
    width_per_group = 4
    return ResNet(Bottleneck, [3, 4, 6, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)


def resnext101_32x8d(num_classes=10, include_top=True):
    groups = 32
    width_per_group = 8
    return ResNet(Bottleneck, [3, 4, 23, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)
2.4 设置训练集和验证集
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device))

BATCH_SIZE = 16

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
    "val": transforms.Compose([transforms.Resize(256),
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

train_dataset = datasets.ImageFolder("../input/10-monkey-species/training/training/", transform=data_transform["train"])  # 训练集数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                           num_workers=2)  # 加载数据

val_dataset = datasets.ImageFolder("../input/10-monkey-species/validation/validation/", transform=data_transform["val"])  # 测试集数据
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                                         num_workers=2)  # 加载数据
2.5 开始训练
net = resnet34()
loss_function = nn.CrossEntropyLoss()  # 设置损失函数
optimizer = optim.Adam(net.parameters(), lr=0.0001)  # 设置优化器和学习率
epoch = 60

history = train_and_val(epoch, net, train_loader, val_loader, loss_function, optimizer)
执行结果
Epoch:55/60.. Train Acc: 0.813.. Val Acc: 0.860.. Train Loss: 0.038.. Val Loss: 0.029.. Time: 38.40s
100%|██████████| 69/69 [00:28<00:00,  2.38it/s]
100%|██████████| 17/17 [00:09<00:00,  1.81it/s]
Epoch:56/60.. Train Acc: 0.830.. Val Acc: 0.882.. Train Loss: 0.031.. Val Loss: 0.025.. Time: 38.84s
100%|██████████| 69/69 [00:27<00:00,  2.48it/s]
100%|██████████| 17/17 [00:09<00:00,  1.78it/s]
Epoch:57/60.. Train Acc: 0.843.. Val Acc: 0.871.. Train Loss: 0.031.. Val Loss: 0.025.. Time: 37.80s
100%|██████████| 69/69 [00:28<00:00,  2.39it/s]
100%|██████████| 17/17 [00:09<00:00,  1.86it/s]
Epoch:58/60.. Train Acc: 0.829.. Val Acc: 0.827.. Train Loss: 0.030.. Val Loss: 0.035.. Time: 38.49s
100%|██████████| 69/69 [00:28<00:00,  2.39it/s]
100%|██████████| 17/17 [00:09<00:00,  1.86it/s]
Epoch:59/60.. Train Acc: 0.852.. Val Acc: 0.853.. Train Loss: 0.029.. Val Loss: 0.031.. Time: 38.42s
100%|██████████| 69/69 [00:28<00:00,  2.39it/s]
100%|██████████| 17/17 [00:08<00:00,  1.90it/s]
Epoch:60/60.. Train Acc: 0.826.. Val Acc: 0.831.. Train Loss: 0.032.. Val Loss: 0.035.. Time: 38.25s
Total time: 38.95 m
2.6 打印准确率以及loss曲线
def plot_loss(x, history):
    plt.plot(x, history['val_loss'], label='val', marker='o')
    plt.plot(x, history['train_loss'], label='train', marker='o')
    plt.title('Loss per epoch')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(), plt.grid()
    plt.show()


def plot_acc(x, history):
    plt.plot(x, history['train_acc'], label='train_acc', marker='x')
    plt.plot(x, history['val_acc'], label='val_acc', marker='x')
    plt.title('Score per epoch')
    plt.ylabel('score')
    plt.xlabel('epoch')
    plt.legend(), plt.grid()
    plt.show()

plot_loss(np.arange(0,epoch), history)
plot_acc(np.arange(0,epoch), history)
执行结果

loss曲线
在这里插入图片描述

24159.png?origin_url=%2Fimg%2FbVb4we&pos_id=img-N4ORxcoi-1724307783733)

准确率曲线
在这里插入图片描述

2.7 查看每一类的准确率
classes = ('n0', 'n1', 'n2', 'n3', 'n4', 'n5', 'n6', 'n7', 'n8', 'n9')

class_correct = [0.] * 10
class_total = [0.] * 10 
y_test, y_pred = [] , []
X_test = []

with torch.no_grad():
    for images, labels in val_loader:
        X_test.extend([_ for _ in images])
        outputs = model(images.to(device))
        _, predicted = torch.max(outputs, 1)
        predicted = predicted.cpu()
        c = (predicted == labels).squeeze()
        for i, label in enumerate(labels):
            class_correct[label] += c[i].item()
            class_total[label] += 1
        y_pred.extend(predicted.numpy())
        y_test.extend(labels.cpu().numpy())      
        
for i in range(10):
    print(f"Acuracy of {classes[i]:5s}: {100*class_correct[i]/class_total[i]:2.0f}%")
执行结果
Acuracy of n0   : 77%
Acuracy of n1   : 86%
Acuracy of n2   : 85%
Acuracy of n3   : 87%
Acuracy of n4   : 85%
Acuracy of n5   : 89%
Acuracy of n6   : 73%
Acuracy of n7   : 75%
Acuracy of n8   : 89%
Acuracy of n9   : 85%
2.8 查看precision,recall和f1-score
from sklearn.metrics import confusion_matrix, classification_report

ac = accuracy_score(y_test, y_pred)
cm = confusion_matrix(y_test, y_pred)
cr = classification_report(y_test, y_pred, target_names=classes)
print("Accuracy is :",ac) 
print(cr)
执行结果
Accuracy is : 0.8308823529411765
              precision    recall  f1-score   support

          n0       0.77      0.77      0.77        26
          n1       0.69      0.86      0.76        28
          n2       1.00      0.85      0.92        27
          n3       0.93      0.87      0.90        30
          n4       0.88      0.85      0.86        26
          n5       0.81      0.89      0.85        28
          n6       0.90      0.73      0.81        26
          n7       0.84      0.75      0.79        28
          n8       0.89      0.89      0.89        27
          n9       0.71      0.85      0.77        26

    accuracy                           0.83       272
   macro avg       0.84      0.83      0.83       272
weighted avg       0.84      0.83      0.83       272
2.9 查看混淆矩阵
import seaborn as sns, pandas as pd

labels = pd.DataFrame(cm).applymap(lambda v: f"{v}" if v!=0 else f"")
plt.figure(figsize=(7,5))
sns.heatmap(cm, annot=labels, fmt='s', xticklabels=classes, yticklabels=classes, linewidths=0.1 )
plt.show()

在这里插入图片描述

3.模型部署在Android

3.1 导出onnx模型
INPUT_DICT = './weight/best.pth'
OUT_ONNX = './weight/best.onnx'

x = torch.randn(1, 3, 224, 224)
input_names = ["input"]
out_names = ["output"]

model= torch.load(INPUT_DICT, map_location=torch.device('cpu'))
model.eval()

torch.onnx._export(model, x, OUT_ONNX, export_params=True, training=False, input_names=input_names, output_names=out_names)
print('please run: python -m onnxsim test.onnx test_sim.onnx\n')
3.2 将onnx模型简化
python -m onnxsim best.onnx best_sim.onnx

在这里插入图片描述

3.3 使用ncnn进行转化
首先转化为.param和.bin文件
onnx2ncnn.exe best_sim.onnx res.param res.bin

在这里插入图片描述

将.param和.bin文件加密
ncnn2mem.exe res.param res.bin res.id.h res.mem.h

在这里插入图片描述

3.4 最终效果

在这里插入图片描述

测试的时候发现,将图片稍微裁剪一下,猴子区域占整幅图像的比例大一点效果较好。

代码开源(仅供参考)

1.完整训练代码:https://github.com/yaoyi30/ResNet_Image_Classification_PyTorch

2.安卓代码:https://github.com/yaoyi30/ResNet_ncnn_android

3.我的CSDN:姚先生97的博客_CSDN博客

作者:YaoXiansheng
文章来源:知乎

推荐阅读

更多芯擎AI开发板干货请关注芯擎AI开发板专栏。欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。@TOC

欢迎使用Markdown编辑器

你好! 这是你第一次使用 Markdown编辑器 所展示的欢迎页。如果你想学习如何使用Markdown编辑器, 可以仔细阅读这篇文章,了解一下Markdown的基本语法知识。

新的改变

我们对Markdown编辑器进行了一些功能拓展与语法支持,除了标准的Markdown编辑器功能,我们增加了如下几点新功能,帮助你用它写博客:

  1. 全新的界面设计 ,将会带来全新的写作体验;
  2. 在创作中心设置你喜爱的代码高亮样式,Markdown 将代码片显示选择的高亮样式 进行展示;
  3. 增加了 图片拖拽 功能,你可以将本地的图片直接拖拽到编辑区域直接展示;
  4. 全新的 KaTeX数学公式 语法;
  5. 增加了支持甘特图的mermaid语法1 功能;
  6. 增加了 多屏幕编辑 Markdown文章功能;
  7. 增加了 焦点写作模式、预览模式、简洁写作模式、左右区域同步滚轮设置 等功能,功能按钮位于编辑区域与预览区域中间;
  8. 增加了 检查列表 功能。

功能快捷键

撤销:Ctrl/Command + Z
重做:Ctrl/Command + Y
加粗:Ctrl/Command + B
斜体:Ctrl/Command + I
标题:Ctrl/Command + Shift + H
无序列表:Ctrl/Command + Shift + U
有序列表:Ctrl/Command + Shift + O
检查列表:Ctrl/Command + Shift + C
插入代码:Ctrl/Command + Shift + K
插入链接:Ctrl/Command + Shift + L
插入图片:Ctrl/Command + Shift + G
查找:Ctrl/Command + F
替换:Ctrl/Command + G

合理的创建标题,有助于目录的生成

直接输入1次#,并按下space后,将生成1级标题。
输入2次#,并按下space后,将生成2级标题。
以此类推,我们支持6级标题。有助于使用TOC语法后生成一个完美的目录。

如何改变文本的样式

强调文本 强调文本

加粗文本 加粗文本

标记文本

删除文本

引用文本

H2O is是液体。

210 运算结果是 1024.

插入链接与图片

链接: link.

图片: Alt

带尺寸的图片: Alt

居中的图片: Alt

居中并且带尺寸的图片: Alt

当然,我们为了让用户更加便捷,我们增加了图片拖拽功能。

如何插入一段漂亮的代码片

博客设置页面,选择一款你喜欢的代码片高亮样式,下面展示同样高亮的 代码片.

// An highlighted block
var foo = 'bar';

生成一个适合你的列表

  • 项目
    • 项目
      • 项目
  1. 项目1
  2. 项目2
  3. 项目3
  • 计划任务
  • 完成任务

创建一个表格

一个简单的表格是这么创建的:

项目Value
电脑$1600
手机$12
导管$1

设定内容居中、居左、居右

使用:---------:居中
使用:----------居左
使用----------:居右

第一列第二列第三列
第一列文本居中第二列文本居右第三列文本居左

SmartyPants

SmartyPants将ASCII标点字符转换为“智能”印刷标点HTML实体。例如:

TYPEASCIIHTML
Single backticks'Isn't this fun?'‘Isn’t this fun?’
Quotes"Isn't this fun?"“Isn’t this fun?”
Dashes-- is en-dash, --- is em-dash– is en-dash, — is em-dash

创建一个自定义列表

Markdown
Text-to- HTML conversion tool
Authors
John
Luke

如何创建一个注脚

一个具有注脚的文本。2

注释也是必不可少的

Markdown将文本转换为 HTML

KaTeX数学公式

您可以使用渲染LaTeX数学表达式 KaTeX:

Gamma公式展示 Γ ( n ) = ( n − 1 ) ! ∀ n ∈ N \Gamma(n) = (n-1)!\quad\forall n\in\mathbb N Γ(n)=(n1)!nN 是通过欧拉积分

Γ ( z ) = ∫ 0 ∞ t z − 1 e − t d t   . \Gamma(z) = \int_0^\infty t^{z-1}e^{-t}dt\,. Γ(z)=0tz1etdt.

你可以找到更多关于的信息 LaTeX 数学表达式here.

新的甘特图功能,丰富你的文章

2014-01-07 2014-01-09 2014-01-11 2014-01-13 2014-01-15 2014-01-17 2014-01-19 2014-01-21 已完成 进行中 计划一 计划二 现有任务 Adding GANTT diagram functionality to mermaid
  • 关于 甘特图 语法,参考 这儿,

UML 图表

可以使用UML图表进行渲染。 Mermaid. 例如下面产生的一个序列图:

张三 李四 王五 你好!李四, 最近怎么样? 你最近怎么样,王五? 我很好,谢谢! 我很好,谢谢! 李四想了很长时间, 文字太长了 不适合放在一行. 打量着王五... 很好... 王五, 你怎么样? 张三 李四 王五

这将产生一个流程图。:

链接
长方形
圆角长方形
菱形
  • 关于 Mermaid 语法,参考 这儿,

FLowchart流程图

我们依旧会支持flowchart的流程图:

Created with Raphaël 2.3.0 开始 我的操作 确认? 结束 yes no
  • 关于 Flowchart流程图 语法,参考 这儿.

导出与导入

导出

如果你想尝试使用此编辑器, 你可以在此篇文章任意编辑。当你完成了一篇文章的写作, 在上方工具栏找到 文章导出 ,生成一个.md文件或者.html文件进行本地保存。

导入

如果你想加载一篇你写过的.md文件,在上方工具栏可以选择导入功能进行对应扩展名的文件导入,
继续你的创作。


  1. mermaid语法说明 ↩︎

  2. 注脚的解释 ↩︎

<think>好的,我需要回答用户关于PyTorch Mobile的使用指南和实现方法的问题。首先,我得回想一下PyTorch Mobile的主要功能和特点。记得PyTorch Mobile是PyTorch针对移动端部署优化的版本,支持在iOS和Android上运行模型。 接下来,我应该分步骤说明如何使用。首先可能需要提到模型准备,比如如何将训练好的模型转换为移动端支持的格式。这里应该提到TorchScript,因为PyTorch Mobile依赖TorchScript来序列化模型。用户需要将模型转换成TorchScript格式,可以通过跟踪或脚本化两种方法。例如,使用torch.jit.trace或者torch.jit.script。 然后是移动端的集成部分。对于Android,可能需要通过Gradle添加依赖,加载模型文件,并编写推理代码。对于iOS,可能需要通过CocoaPods添加依赖,同样加载模型并执行推理。这里需要给出具体的代码示例,比如Android中使用org.pytorch的库,iOS中使用LibTorch的框架。 还要提到性能优化,比如量化。PyTorch支持训练后的动态量化和量化感知训练,这可以减小模型大小,提升推理速度。需要引用用户提供的参考内容,比如引用[2]提到的量化方法,动态量化和量化感知训练的区别。这里要说明在移动端使用量化模型的好处,如减少内存占用和加快计算速度。 另外,可能需要注意的事项,比如模型版本兼容性、设备性能差异、内存管理等。比如建议在真实设备上测试,使用轻量级模型结构,或者使用硬件加速API如NNAPI。 最后,生成相关问题,帮助用户深入理解。比如如何优化模型尺寸,量化方法选择,性能分析工具,多平台适配问题,以及实时性要求高的场景应用。 需要确保所有数学表达式用$...$,比如量化公式$W_{int8} = \text{round}(W_{float32}/scale)$。独立公式用$$...$$,如果有的话。同时检查LaTeX语法是否正确,使用中文回答,结构清晰。 可能还需要参考用户提供的引用内容,比如引用[1]提到PyTorch的动态计算图和JIT编译器,这和TorchScript相关,可以联系起来说明性能优化。引用[2]的量化方法,可以具体说明在移动端的应用。 现在整理这些点,分步骤说明,给出代码示例,注意事项,以及相关问题。</think>### PyTorch Mobile 使用指南与实现方式 PyTorch Mobile 是专为移动端设计的轻量级推理框架,支持在 iOS 和 Android 设备上高效运行训练好的模型。以下是具体实现步骤与注意事项: --- #### **一、模型准备与导出** 1. **模型转换** 使用 TorchScript 将 PyTorch 模型序列化为移动端可用的格式: ```python import torch model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True) model.eval() # 方法1:通过跟踪(Tracing)生成 TorchScript example_input = torch.rand(1, 3, 224, 224) traced_model = torch.jit.trace(model, example_input) traced_model.save("resnet18_traced.pt") # 方法2:通过脚本化(Scripting)生成 TorchScript scripted_model = torch.jit.script(model) scripted_model.save("resnet18_scripted.pt") ``` - **跟踪**适用于模型结构由数据流控制的场景(如 CNN)。 - **脚本化**适用于控制流复杂的模型(如 RNN)[^1]。 2. **量化优化(可选)** 通过量化减小模型体积并提升推理速度: ```python # 动态量化(Post-training Dynamic Quantization) quantized_model = torch.quantization.quantize_dynamic( model, {torch.nn.Linear}, dtype=torch.qint8 ) torch.jit.save(torch.jit.script(quantized_model), "resnet18_quantized.pt") ``` 量化公式可表示为: $$W_{int8} = \text{round}(W_{float32}/scale)$$ 其中 $scale$ 为缩放因子,用于保留精度[^2]。 --- #### **二、移动端集成** 1. **Android 实现** - **步骤1:添加依赖** 在 `build.gradle` 中引入 PyTorch Mobile: ```gradle dependencies { implementation 'org.pytorch:pytorch_android_lite:1.12.1' implementation 'org.pytorch:pytorch_android_torchvision:1.12.1' } ``` - **步骤2:加载模型并推理** ```java Module module = LiteModuleLoader.load(assetFilePath(this, "resnet18_quantized.pt")); Tensor input = TensorImageUtils.bitmapToFloat32Tensor(bitmap, ImageNormalizationParameters.NORMALIZE_0_1); Tensor output = module.forward(IValue.from(input)).toTensor(); ``` 2. **iOS 实现** - **步骤1:通过 CocoaPods 安装** 在 `Podfile` 中添加: ```ruby pod 'LibTorch-Lite', '~> 1.12.1' ``` - **步骤2:Swift/Objective-C 调用** ```swift let modelPath = Bundle.main.path(forResource: "resnet18_quantized", ofType: "pt") let module = try? TorchModule.loadModel(at: modelPath) let input = TorchTensor.from(imageData) // 图像需预处理为 Tensor let output = module?.forward(with: input) ``` --- #### **三、性能优化建议** 1. **模型压缩** - 使用**量化感知训练**(Quantization-aware Training)减少精度损失[^2]。 - 选择轻量级架构(如 MobileNetV3、EfficientNet-Lite)。 2. **硬件加速** - Android 上启用 NNAPI: ```java PyTorchAndroid.setUseNNAPI(true); ``` - iOS 使用 Core ML 转换工具(需通过 `coremltools` 转换模型)。 3. **内存管理** - 避免频繁加载/释放模型,尽量复用实例。 - 限制输入分辨率,如将图像缩放至 224x224。 --- #### **四、注意事项** 1. **兼容性** - 确保 PyTorch Mobile 版本与训练环境一致。 - 测试不同设备(如低端 Android 机型)的推理速度。 2. **动态形状支持** TorchScript 默认支持固定输入尺寸,若需动态输入,需在导出时指定: ```python traced_model = torch.jit.trace(model, example_input, strict=False) ``` ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值