关于将Pytorch模型部署到安卓移动端方法总结

一、Android Studio环境配置

1.安装包下载问题解决

在Android Studio官网下载编译工具时,会出现无法下载的问题,可右键复制下载链接IDMan中进行下载。

2.安装

安装过程中,需要将Android Virtual Device勾选,否则无法使用虚拟机。

安装启动后,会提示没有SDK,设置代码,直接选择cancel键。

完后,会有专门的SKD组件的安装,但是会有unavailable不可安装的情况出现,可通过创建项目后配置gradle后便可以安装了。

二、项目创建

软件安装后可能出现打不开的情况,可选择以管理员身份启动即可解决问题。

选择New Project

选择喜欢的界面样式即可。

使用语言、SDK根据自行需求进行选择就行。

Build configuration language建议选择Kotlin DSL(build.gradle.kts)[Recommended],否则会出现缺少gradle文件的情况。

创建完后会出现如下项目目录,并不会直接出现app的文件夹,需要手动配置gradle。

按照如下目录gradle/wrapper/gradle-wrapper.properties修改distributionUrl为本地地址。(根据原先的地址下载对应的压缩包)

#Wed May 01 21:02:04 CST 2024
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.4-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

更变为
#Wed May 01 21:02:04 CST 2024
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
# 对应的gradle-8.4-bin.zip本地地址即可
distributionUrl=file:///D://Android//gradle-8.4-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

在settings.gradle.kts更换阿里源(直接复制粘贴即可)


pluginManagement {
    repositories {
        maven { url=uri ("https://www.jitpack.io")}
        maven { url=uri ("https://maven.aliyun.com/repository/releases")}
        maven { url=uri ("https://maven.aliyun.com/repository/google")}
        maven { url=uri ("https://maven.aliyun.com/repository/central")}
        maven { url=uri ("https://maven.aliyun.com/repository/gradle-plugin")}
        maven { url=uri ("https://maven.aliyun.com/repository/public")}

        google()
        mavenCentral()
        gradlePluginPortal()
    }
}
dependencyResolutionManagement {
    repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
    repositories {
        maven { url=uri ("https://www.jitpack.io")}
        maven { url=uri ("https://maven.aliyun.com/repository/releases")}
        maven { url=uri ("https://maven.aliyun.com/repository/google")}
        maven { url=uri ("https://maven.aliyun.com/repository/central")}
        maven { url=uri ("https://maven.aliyun.com/repository/gradle-plugin")}
        maven { url=uri ("https://maven.aliyun.com/repository/public")}


        google()
        mavenCentral()
    }
}

rootProject.name = "Helloword"
include(":app")

在build.gradle.kts中点击sync now即可自动配置,稍等即可便可变成app文件夹的形式。

选择Project,变成全部文件的形式。

初始新建项目即刻完成。

三、训练模型权重转化

需将训练好的.pth文件转化为.pt文件

"""
该程序使用的是resnet32网络,用到其他网络可自行更改
保存的权重字典目录如下所示。
      ckpt = {
            'weight': model.state_dict(),
            'epoch': epoch,
            'cfg': opt.model,
            'index': name
        }
"""
from models.resnet_cifar import resnet32  # 确保引用你的正确模型架构
import torch
import torch.nn as nn
# 假设你的ResNet定义在resnet.py文件中
model = resnet32()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 100)  # 修改这里的100为你的类别数

# 加载权重
checkpoint = torch.load('modelleader_best.pth', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['weight'], strict=False)  # 使用strict=False可以忽略不匹配的键

model.eval()
# 将模型转换为TorchScript
example_input = torch.rand(1, 3, 32, 32)  # 修改这里以匹配你的模型输入尺寸
traced_script_module = torch.jit.trace(model, example_input)
traced_script_module.save("model.pt")

四、Pytorch项目搭建工作

在如下目录下创建assets文件,将转化好的模型放在里面即可,切记不可直接创建文件夹,会出现找不到模型问题。

在com/example/myapplication下创建了两个类cifarClassed,MainActivity。

MainActivity类
package com.example.myapplication;

import android.content.Context;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.provider.MediaStore;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;

import androidx.annotation.NonNull;
import androidx.appcompat.app.AppCompatActivity;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;
import androidx.core.content.FileProvider;

import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

public class MainActivity extends AppCompatActivity {
    private static final int PERMISSION_REQUEST_CODE = 101;

    private static final int REQUEST_IMAGE_CAPTURE = 1;
    private static final int REQUEST_IMAGE_SELECT = 2;
    private ImageView imageView;
    private TextView textView;
    private Module module;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        // 检查相机权限
        if (ContextCompat.checkSelfPermission(this, android.Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
            ActivityCompat.requestPermissions(this, new String[]{android.Manifest.permission.CAMERA}, PERMISSION_REQUEST_CODE);
        }

        imageView = findViewById(R.id.image);
        textView = findViewById(R.id.text);
        ImageView logoImageView = findViewById(R.id.logo);
        logoImageView.setImageResource(R.drawable.logo);


        Button takePhotoButton = findViewById(R.id.button_take_photo);
        Button selectImageButton = findViewById(R.id.button_select_image);

        takePhotoButton.setOnClickListener(v -> dispatchTakePictureIntent());
        selectImageButton.setOnClickListener(v -> dispatchGalleryIntent());

        try {
            module = Module.load(assetFilePath(this, "model.pt"));
        } catch (IOException e) {
            Log.e("PytorchHelloWorld", "Error reading assets", e);
            finish();
        }
    }
    @Override
    public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
        super.onRequestPermissionsResult(requestCode, permissions, grantResults);
        if (requestCode == PERMISSION_REQUEST_CODE) {
            if (grantResults.length > 0 && grantResults[0] == PackageManager.PERMISSION_GRANTED) {
                // 权限被授予
                Log.d("Permissions", "Camera permission granted");
            } else {
                // 权限被拒绝
                Log.d("Permissions", "Camera permission denied");
            }
        }
    }
    private void dispatchTakePictureIntent() {
        Intent takePictureIntent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
        if (takePictureIntent.resolveActivity(getPackageManager()) != null) {
            startActivityForResult(takePictureIntent, REQUEST_IMAGE_CAPTURE);
        }
    }

    private void dispatchGalleryIntent() {
        Intent intent = new Intent(Intent.ACTION_PICK, MediaStore.Images.Media.EXTERNAL_CONTENT_URI);
        startActivityForResult(intent, REQUEST_IMAGE_SELECT);
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        super.onActivityResult(requestCode, resultCode, data);
        if (resultCode == RESULT_OK && (requestCode == REQUEST_IMAGE_CAPTURE || requestCode == REQUEST_IMAGE_SELECT)) {
            Bitmap imageBitmap = null;
            if (requestCode == REQUEST_IMAGE_CAPTURE) {
                Bundle extras = data.getExtras();
                imageBitmap = (Bitmap) extras.get("data");
            } else if (requestCode == REQUEST_IMAGE_SELECT) {
                try {
                    imageBitmap = MediaStore.Images.Media.getBitmap(this.getContentResolver(), data.getData());
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
            imageView.setImageBitmap(imageBitmap);
            classifyImage(imageBitmap);
        }
    }

//    private void classifyImage(Bitmap bitmap) {
//        Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
//                TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
//        Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
//        float[] scores = outputTensor.getDataAsFloatArray();
//        float maxScore = -Float.MAX_VALUE;
//        int maxScoreIdx = -1;
//        for (int i = 0; i < scores.length; i++) {
//            if (scores[i] > maxScore) {
//                maxScore = scores[i];
//                maxScoreIdx = i;
//            }
//        }
//        textView.setText("推理结果:" + CifarClassed.IMAGENET_CLASSES[maxScoreIdx]);
//        textView.setVisibility(View.VISIBLE); // 设置 TextView 可见
//    }
//    private void classifyImage(Bitmap bitmap) {
//        // 调整图像大小为 32x32 像素
//        Bitmap resizedBitmap = resizeBitmap(bitmap, 32, 32);
//
//        // 将调整大小后的图像转换为 PyTorch Tensor
//        Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(resizedBitmap,
//                new float[]{0.485f, 0.456f, 0.406f}, // 均值 Mean
//                new float[]{0.229f, 0.224f, 0.225f}); // 标准差 Std
//
//        // 推理
//        Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
//        float[] scores = outputTensor.getDataAsFloatArray();
//        float maxScore = -Float.MAX_VALUE;
//        int maxScoreIdx = -1;
//        for (int i = 0; i < scores.length; i++) {
//            if (scores[i] > maxScore) {
//                maxScore = scores[i];
//                maxScoreIdx = i;
//            }
//        }
//        textView.setText("推理结果:" + CifarClassed.IMAGENET_CLASSES[maxScoreIdx]);
//        textView.setVisibility(View.VISIBLE); // 设置 TextView 可见
//    }
//
    private float[] softmax(float[] scores) {
        float max = Float.NEGATIVE_INFINITY;
        for (float score : scores) {
            if (score > max) max = score;
        }
        float sum = 0.0f;
        float[] exps = new float[scores.length];
        for (int i = 0; i < scores.length; i++) {
            exps[i] = (float) Math.exp(scores[i] - max); // 减去最大值防止指数爆炸
            sum += exps[i];
        }
        for (int i = 0; i < exps.length; i++) {
            exps[i] /= sum; // 归一化
        }
        return exps;
    }

    // 图像分类方法
    private void classifyImage(Bitmap bitmap) {
        // 调整图像大小为 32x32 像素
        Bitmap resizedBitmap = resizeBitmap(bitmap, 32, 32);

        // 将调整大小后的图像转换为 PyTorch Tensor
        Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(resizedBitmap,
                new float[]{0.485f, 0.456f, 0.406f}, // 使用训练时相同的均值 Mean
                new float[]{0.229f, 0.224f, 0.225f}); // 使用训练时相同的标准差 Std

        // 推理
        Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
        float[] scores = outputTensor.getDataAsFloatArray();
        // 应用自定义的 Softmax 函数获取概率分布
        float[] probabilities = softmax(scores);
        float maxScore = -Float.MAX_VALUE;
        int maxScoreIdx = -1;
        for (int i = 0; i < probabilities.length; i++) {
            if (probabilities[i] > maxScore) {
                maxScore = probabilities[i];
                maxScoreIdx = i;
            }
        }

        // 更新 UI 必须在主线程中完成
        final int maxIndex = maxScoreIdx;
        final float finalMaxScore = maxScore;
        runOnUiThread(new Runnable() {
            @Override
            public void run() {
                textView.setText("推理结果:" + CifarClassed.IMAGENET_CLASSES[maxIndex] + " (" + String.format("%.2f%%", finalMaxScore * 100) + ")");
                textView.setVisibility(View.VISIBLE); // 设置 TextView 可见
            }
        });
    }

///

    //
    // 方法来调整 Bitmap 的大小
    private Bitmap resizeBitmap(Bitmap originalBitmap, int targetWidth, int targetHeight) {
        return Bitmap.createScaledBitmap(originalBitmap, targetWidth, targetHeight, false);
    }

    public static String assetFilePath(Context context, String assetName) throws IOException {
        File file = new File(context.getFilesDir(), assetName);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath();
        }

        try (InputStream is = context.getAssets().open(assetName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        }
    }

}

CifarClassed类
package com.example.myapplication;


public class CifarClassed {
    public static String[] IMAGENET_CLASSES = new String[]{
            "apple", "aquarium_fish", "baby", "bear", "beaver", "bed", "bee", "beetle",
            "bicycle", "bottle", "bowl", "boy", "bridge", "bus", "butterfly", "camel",
            "can", "castle", "caterpillar", "cattle", "chair", "chimpanzee", "clock",
            "cloud", "cockroach", "couch", "crab", "crocodile", "cup", "dinosaur",
            "dolphin", "elephant", "flatfish", "forest", "fox", "girl", "hamster", "house",
            "kangaroo", "keyboard", "lamp", "lawn_mower", "leopard", "lion", "lizard",
            "lobster", "man", "maple_tree", "motorcycle", "mountain", "mouse", "mushroom",
            "oak_tree", "orange", "orchid", "otter", "palm_tree", "pear", "pickup_truck",
            "pine_tree", "plain", "plate", "poppy", "porcupine", "possum", "rabbit", "raccoon",
            "ray", "road", "rocket", "rose", "sea", "seal", "shark", "shrew", "skunk",
            "skyscraper", "snail", "snake", "spider", "squirrel", "streetcar", "sunflower",
            "sweet_pepper", "table", "tank", "telephone", "television", "tiger", "tractor",
            "train", "trout", "tulip", "turtle", "wardrobe", "whale", "willow_tree", "wolf",
            "woman", "worm"
    };
}

页面布局存放在MyApplication\app\src\main\res\layout\activity_main.xml文件中。

<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity"
    android:background="#F0F0F0">

    <LinearLayout
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_gravity="center_vertical"
        android:orientation="vertical"
        android:gravity="center">

        <ImageView
            android:id="@+id/image"
            android:layout_width="200dp"
            android:layout_height="200dp"
            android:scaleType="centerCrop"
            android:elevation="2dp" />

        <!-- 推理结果显示在图片与按钮之间的空白区域 -->
        <TextView
            android:id="@+id/text"
            android:layout_width="wrap_content"
            android:layout_height="wrap_content"
            android:textSize="24sp"
            android:textColor="#FFF"
            android:gravity="center"
            android:layout_marginTop="16dp"
            android:layout_marginBottom="16dp"
            android:visibility="gone" /> <!-- 初始状态隐藏 -->
    </LinearLayout>

    <!-- 按钮位于屏幕底部 -->
    <LinearLayout
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:orientation="horizontal"
        android:layout_gravity="bottom"
        android:elevation="4dp">

        <Button
            android:id="@+id/button_take_photo"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:layout_weight="1"
            android:text="拍照"
            android:backgroundTint="#FF6200EE"
            android:textColor="#FFFFFF"
            android:layout_margin="8dp"
            android:elevation="2dp"
            android:stateListAnimator="@null"/>

        <Button
            android:id="@+id/button_select_image"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:layout_weight="1"
            android:text="选择图片"
            android:backgroundTint="#FF018786"
            android:textColor="#FFFFFF"
            android:layout_margin="8dp"
            android:elevation="2dp"
            android:stateListAnimator="@null"/>
    </LinearLayout>

    <!-- 调整商标为小圆形并放置在顶部中间 -->
    <!-- 调整商标为小圆形并放置在顶部中间使用 CircleImageView -->
    <de.hdodenhof.circleimageview.CircleImageView
        android:id="@+id/logo"
        android:src="@drawable/logo"
        android:layout_width="50dp"
        android:layout_height="50dp"
        android:layout_gravity="top|center_horizontal"
        android:layout_marginTop="16dp"
        android:elevation="5dp"/>
</FrameLayout>

在MyApplication\app\src\main\res\drawable\circle_shape.xml(自行创建)

<?xml version="1.0" encoding="utf-8"?>
<shape xmlns:android="http://schemas.android.com/apk/res/android"
    android:shape="oval">
    <solid android:color="#FFFFFF"/>  <!-- 修改颜色以匹配你的需求 -->
    <size
        android:width="50dp"
        android:height="50dp"/>  <!-- 定义圆的尺寸,确保它与 ImageView 的尺寸相匹配 -->
</shape>

在MyApplication\app\src\main\res\drawable\rounded_background(自行创建)

<?xml version="1.0" encoding="utf-8"?>
<shape xmlns:android="http://schemas.android.com/apk/res/android">
    <solid android:color="#FFFFFF"/>  <!-- 背景色,半透明黑 -->
    <corners android:radius="10dp"/>  <!-- 圆角的大小 -->
</shape>

在MyApplication\app\src\main\AndroidManifest.xml添加相机与读取照片的权限。

<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools">
    <uses-feature android:name="android.hardware.camera" android:required="true"/>
    <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
    <uses-permission android:name="android.permission.CAMERA" />




    <application
        android:allowBackup="true"
        android:dataExtractionRules="@xml/data_extraction_rules"
        android:fullBackupContent="@xml/backup_rules"
        android:icon="@mipmap/ic_launcher"
        android:label="@string/app_name"
        android:roundIcon="@mipmap/ic_launcher_round"
        android:supportsRtl="true"
        android:theme="@style/Theme.MyApplication"
        tools:targetApi="31">
        <activity
            android:name=".MainActivity"
            android:exported="true">
            <intent-filter>
                <action android:name="android.intent.action.MAIN" />
                <category android:name="android.intent.category.LAUNCHER" />
            </intent-filter>
        </activity>
    </application>

</manifest>

app级别build.gradle.kts(MyApplication\app\build.gradle.kts)配置如下。

plugins {
    alias(libs.plugins.androidApplication)
}

android {
    namespace = "com.example.myapplication"
    compileSdk = 34
    sourceSets {
        getByName("main") {
            jniLibs.srcDir("libs")
        }
    }

    packaging {
        resources.excludes.add("META-INF/*")
    }


    defaultConfig {
        applicationId = "com.example.myapplication"
        minSdk = 24
        targetSdk = 34
        versionCode = 1
        versionName = "1.0"
        testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
    }

    buildTypes {
        release {
            isMinifyEnabled = false
            proguardFiles(getDefaultProguardFile("proguard-android-optimize.txt"), "proguard-rules.pro")
        }
    }

    compileOptions {
        sourceCompatibility = JavaVersion.VERSION_1_8
        targetCompatibility = JavaVersion.VERSION_1_8
    }
}


dependencies {
    // 使用 alias 来指定库,确保 libs.aliases.gradle 中已经定义了这些别名
    implementation(libs.appcompat)
    implementation(libs.material)
    implementation(libs.activity)
    implementation(libs.constraintlayout)
    testImplementation(libs.junit)
    androidTestImplementation(libs.ext.junit)
    androidTestImplementation(libs.espresso.core)
    implementation("org.pytorch:pytorch_android:1.12.1")
    implementation("org.pytorch:pytorch_android_torchvision:1.12.1")
    implementation("com.google.android.exoplayer:exoplayer:2.14.1")
    implementation("androidx.localbroadcastmanager:localbroadcastmanager:1.0.0")
    implementation("androidx.activity:activity:1.2.0")
    implementation("androidx.fragment:fragment:1.3.0")
    implementation("de.hdodenhof:circleimageview:3.1.0")




}

这段可解决如下bug。

    packaging {
        resources.excludes.add("META-INF/*")
    }
Caused by: com.android.builder.merge.DuplicateRelativeFileException: 2 files found with path ‘META-INF/androidx.core_core.version’.

手动添加非常麻烦,因为不止一个文件冲突!!!

完成以上步骤再按下Sync Now完成依赖的配置工作,需在编译器中自行选择虚拟设备。

完成后即可在MainActivity.java文件启动项目。

五、APK安装包导出

 点击create创建即可,便可得到apk文件。

六、效果图

  • 21
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
### 回答1: 要将PyTorch模型部署Android设备上,可以使用以下步骤: 1. 将PyTorch模型转换为ONNX格式。可以使用PyTorch官方提供的torch.onnx模块将模型转换为ONNX格式。 2. 使用ONNX Runtime for Android将ONNX模型部署Android设备上。可以使用ONNX Runtime for Android提供的Java API将模型加载到Android应用程序中。 3. 在Android应用程序中使用模型进行推理。可以使用Java API调用模型进行推理,并将结果返回给应用程序。 需要注意的是,在将模型部署Android设备上之前,需要确保模型的大小和计算量适合在移动设备上运行。可以使用模型压缩和量化等技术来减小模型的大小和计算量。 ### 回答2: PyTorch是一个开源的Python机器学习库,它为深度学习提供了强大的支持。PyTorch模型可以在计算机上进行训练和调试,但当我们需要将模型部署到移动设备(如Android)上时,我们需要将PyTorch模型转换并集成到移动应用程序中,这需要一系列的步骤。 首先,我们需要将PyTorch模型转换为TorchScript格式,这是一种在移动设备上运行的地图。使用TorchScript脚本将PyTorch模型序列化为可运行的形式,它可以在没有Python运行时进行部署。我们可以使用以下代码将PyTorch模型转换为TorchScript格式: ``` import torch import torchvision # load the PyTorch model model = torchvision.models.resnet18(pretrained=True) # set the model to evaluation mode model.eval() # trace the model to generate a TorchScript traced_model = torch.jit.trace(model, torch.randn(1, 3, 224, 224)) ``` 上面的代码将一个预训练的ResNet模型转换为TorchScript格式,现在我们可以将其保存到文件中以备以后使用: ``` traced_model.save('resnet18_model.pt') ``` 接下来,我们需要将TorchScript模型集成到Android应用程序中。我们可以使用Android Studio提供的Android Neural Networks API(NNAPI)来加速我们的深度学习推理。NNAPI是一个Google开发的Android框架,它提供了一些API,可以加速计算机视觉和自然语言处理应用程序中的神经网络推理。我们可以在Gradle文件中添加以下代码,以添加NNAPI支持: ``` dependencies { implementation 'org.pytorch:pytorch_android:1.7.0' implementation 'org.pytorch:pytorch_android_torchvision:1.7.0' } ``` 然后将TorchScript模型文件复制到Android项目中的`assets`文件夹中。 最后,我们需要编写代码将TorchScript模型加载到我们的应用程序中,并使用它来进行推理。下面是一个简单的Android应用程序,可以使用加载的TorchScript模型对图像进行分类: ```java import android.graphics.Bitmap; import android.graphics.BitmapFactory; import android.os.Bundle; import android.widget.ImageView; import android.widget.TextView; import androidx.appcompat.app.AppCompatActivity; import org.pytorch.IValue; import org.pytorch.Module; import org.pytorch.Tensor; public class MainActivity extends AppCompatActivity { private TextView mResultTextView; private ImageView mImageView; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); mResultTextView = findViewById(R.id.result_text_view); mImageView = findViewById(R.id.image_view); // Load the TorchScript model from the assets folder Module module = Module.load(assetFilePath(this, "resnet18_model.pt")); // Load the image and convert it to a PyTorch Tensor Bitmap bitmap = BitmapFactory.decodeResource(this.getResources(), R.drawable.test_image); float[] mean = new float[]{0.485f, 0.456f, 0.406f}; float[] std = new float[]{0.229f, 0.224f, 0.225f}; Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap, mean, std); // Run the input through the model IValue outputTensor = module.forward(IValue.from(inputTensor)); // Get the predicted class index from the output Tensor float[] scores = outputTensor.toTensor().getDataAsFloatArray(); int predictedIndex = -1; float maxScore = 0.0f; for (int i = 0; i < scores.length; i++) { if (scores[i] > maxScore) { predictedIndex = i; maxScore = scores[i]; } } // Display the result String[] classNames = {"cat", "dog", "fish", "horse", "spider"}; mResultTextView.setText("Prediction: " + classNames[predictedIndex]); mImageView.setImageBitmap(bitmap); } public static String assetFilePath(Context context, String assetName) { File file = new File(context.getFilesDir(), assetName); try (InputStream is = context.getAssets().open(assetName)) { try (OutputStream os = new FileOutputStream(file)) { byte[] buffer = new byte[4 * 1024]; int read; while ((read = is.read(buffer)) != -1) { os.write(buffer, 0, read); } os.flush(); } return file.getAbsolutePath(); } catch (IOException e) { e.printStackTrace(); } return null; } } ``` 上面的代码将载入从`assets`文件夹中加载的TorchScript模型,为它准备好图像数据,并将其运行给模型模型返回一个输出张量,我们得到预测的类别。 总之,将PyTorch模型部署Android可以通过转换为TorchScript格式,集成到Android应用程序中,以及编写可以使用它进行推理的代码来实现。厂商和第三方可用工具也可以帮助简化部署过程。 ### 回答3: 在让PyTorch模型部署Android设备之前,你需要确保你的模型可用且现在运行良好。这涉及到以下步骤: 1. 在PyTorch中定义并训练模型 首先在PyTorch中定义并训练模型。你需要训练一个模型,这个模型可以处理你希望在移动设备上使用的数据。你需要确保在训练模型时,使用了适当的数据预处理和清理过程。然后,导出模型以便在Android设备上使用。 2. 将PyTorch模型转换为TorchScript格式 将训练好的PyTorch模型转化成TorchScript格式,这是 PyTorch模型导出方面提供的一种功能强大的框架。你可以使用 torch.jit.load() 函数来加载 TorchScript 模型,并在移动设备上使用它。你可以使用torchscript_builder.py 脚本来转换 PyTorch 模型,这个脚本也可以根据你的需要在运行时执行转换。 3. 集成模型Android应用中: Android应用可以使用自己的Java代码,但也可以使用C++接口以及原生代码。所以,集成模型Android 应用可以使用两种方式: Java 接口和 C++ 接口。 3.1 Java 接口 Java 接口可以用于创建用 Java 编写的 Android 应用程序。以下是使用 Java 接口加载 TorchScript 模型的步骤: - 创建一个 Android 应用程序项目。 - 在 Android Studio 中安装 PyTorch 的 Gradle 插件。 - 将 torch-android 库和 pytorch_android 库添加到项目中的 build.gradle 文件中。 - 在代码中使用 TorchScript 加载模型,并使用该程序的 Android 功能来运行。 3.2 C++ 接口 使用 C++ 接口可以创建用 C++ 编写的 Android 应用程序。以下是使用 C++ 接口加载 TorchScript 模型的步骤: - 创建一个 Android 应用程序项目。 - 编写 C++ 代码来加载 TorchScript 模型。 - 在 Android Studio 中创建一个 Android.mk 文件和 Application.mk 文件。 - 将 C++ 代码编译成共享库,然后将共享库打包到 Android 应用程序 APK 文件中。 - 在代码中使用 TorchScript 加载模型,并调用 C++ 程序的 Android 功能来运行。 以上是部署 PyTorch 模型Android 设备的步骤和过程。在集成模型Android 应用中时,需要注意处理异常和各种错误,确保模型可以在 Android 设备上成功加载。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值