前期准备
- 模型 model.pt
- 一张待识别的图片
- 标签 ImageNetClasses.java(具体代码放在了后面)
目录结构如下
代码
下面的代码,将一张图片喂给神经网络模型,得到输出结果后,将结果显示到界面上。
- xml
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<ImageView
android:id="@+id/imageView"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:scaleType="fitCenter"/>
<TextView
android:id="@+id/textView"
android:layout_width="match_parent"
android:layout_height="wrap_content"
app:layout_constraintLeft_toLeftOf="parent"
app:layout_constraintRight_toRightOf="parent"
app:layout_constraintTop_toTopOf="parent" />
</androidx.constraintlayout.widget.ConstraintLayout>
java
public class MainActivity extends AppCompatActivity {
private ImageView imageView;
private TextView textView;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
imageView = findViewById(R.id.imageView);
textView = findViewById(R.id.textView);
Bitmap bitmap= null;
Module module = null;
try {
// 1. 获取图片
bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
imageView.setImageBitmap(bitmap);
// 2. 加载模型
module =Module.load(assetFilePath(this, "model.pt"));
} catch (IOException e) {
e.printStackTrace();
finish();
}
// 3. bitmap -> Tensor
Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap, TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
// 4. 运行模型
Tensor resultTensor = module.forward(IValue.from(inputTensor)).toTensor();