基于TensorFlow2.3.0的果蔬识别系统的设计
一、开发环境
- Windows 10
- Python 3.7.3
- TensorFlow 2.3.0
- Anaconda 4.12.0
- CUDA 10.1
- cuDNN 7.6.5
二、步骤
2.1 创建一个python 3.7.3的虚拟环境
conda create -n vegetable python==3.7.3
2.2 激活名称为vegetable的虚拟环境
conda activate vegetable
2.3 若只使用CPU训练,安装tensorflow-cpu:
pip install tensorflow-cpu==2.3.0
2.4 (推荐,与2.3二选一)若电脑有支持CUDA的NVIDIA显卡,改为安装tensorflow-gpu版本,需提前安装好CUDA 10.1和cuDNN 7.6.5:
pip install tensorflow-gpu==2.3.0
2.5 准备好果蔬分类数据集,放在相对于运行目录的../data路径下(即训练代码中data_load读取的目录)
2.6 编写训练模型代码,以ImageNet预训练的MobileNetV2为骨干网络进行迁移学习(import语句和数据加载函数data_load未在下方列出)。
# Model construction
def model_load(IMG_SHAPE=(224, 224, 3), class_num=214):
    """Build and compile a MobileNetV2 transfer-learning classifier.

    The ImageNet-pretrained MobileNetV2 backbone (without its top) is frozen,
    so only the new softmax Dense head is trained.

    Args:
        IMG_SHAPE: Input image shape as (height, width, channels).
        class_num: Number of output classes, i.e. width of the softmax layer.

    Returns:
        A ``tf.keras.Sequential`` model compiled with the Adam optimizer,
        categorical cross-entropy loss and an accuracy metric.
    """
    backbone = tf.keras.applications.MobileNetV2(
        input_shape=IMG_SHAPE,
        include_top=False,
        weights='imagenet',
    )
    # Freeze the pretrained weights; only the classification head learns.
    backbone.trainable = False

    layers = tf.keras.layers
    stack = [
        # x / 127.5 - 1 maps pixel values 0..255 onto -1..1 (assumes the
        # dataset yields raw 0..255 pixels -- confirm against data_load).
        layers.experimental.preprocessing.Rescaling(
            1. / 127.5, offset=-1, input_shape=IMG_SHAPE),
        backbone,
        layers.GlobalAveragePooling2D(),
        layers.Dense(class_num, activation='softmax'),
    ]
    model = tf.keras.models.Sequential(stack)

    # Print the layer-by-layer architecture and parameter counts.
    model.summary()
    # categorical_crossentropy expects one-hot labels (presumably what
    # data_load produces -- verify).
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model
# Training
def train(epochs, data_dir="../data", model_dir="models"):
    """Train the classifier, then save it as a Keras .h5 and a TFLite model.

    Args:
        epochs: Number of training epochs.
        data_dir: Dataset root passed to ``data_load`` (default ``"../data"``,
            resolved relative to the current working directory).
        model_dir: Output directory for ``vegetable_model.h5`` and
            ``model.tflite``; created if it does not exist.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    import os

    # 1. Load the dataset. data_load is defined elsewhere; the 224, 224, 16
    #    arguments are presumably height, width and batch size -- confirm.
    train_dataset, validate_dataset, class_names = data_load(data_dir, 224, 224, 16)
    print(class_names)
    # 2. Build the model with one output unit per class found in the data.
    model = model_load(class_num=len(class_names))
    # 3. Train.
    history = model.fit(train_dataset, validation_data=validate_dataset, epochs=epochs)
    # 4. Save the Keras model; make sure the output directory exists first,
    #    otherwise saving to "models/..." fails on a fresh checkout.
    os.makedirs(model_dir, exist_ok=True)
    model.save(os.path.join(model_dir, "vegetable_model.h5"))
    # 5. Convert the in-memory model directly (reloading the .h5 just saved
    #    produces the same model) and write it with a context manager so the
    #    file handle is always closed.
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()
    with open(os.path.join(model_dir, "model.tflite"), "wb") as f:
        f.write(tflite_model)
    return history
if __name__ == '__main__':
    # Script entry point: train for 30 epochs (train() also exports the models).
    train(30)
2.7 训练完成后在models文件夹中得到名称为model.tflite的模型文件,接下来将这个模型文件导入Android Studio工程中。
三、编写Android APP
3.1 打开Android Studio,将model.tflite模型文件拷贝到Android工程的assets目录(app/src/main/assets)中
3.2 同时要在app模块build.gradle文件的android {}块中添加如下内容,使tflite模型文件在打包时不被压缩
// Place inside the android { } block of app/build.gradle.
// Keep *.tflite assets uncompressed in the APK: compressed assets cannot be
// opened via AssetManager.openFd() / memory-mapped, which TFLite asset loaders
// commonly rely on (presumably TFLiteLoader does too -- confirm).
aaptOptions {
    noCompress "tflite"
}
3.3 编写activity_main.xml布局文件
<?xml version="1.0" encoding="utf-8"?>
<!-- Main screen, stacked vertically: title, picture card, scrollable
     description card, and the button that opens the image picker. -->
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:layout_marginTop="0dp"
    android:orientation="vertical">
    <!-- App title. -->
    <TextView
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:gravity="center"
        android:text="基于TensorFlow的果蔬识别系统"
        android:textColor="@color/black"
        android:textSize="25sp" />
    <!-- Rounded card with the picture to classify. MainActivity finds
         iv_vegetable and replaces the @drawable/orange placeholder with the
         picked image. -->
    <androidx.cardview.widget.CardView
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_gravity="center"
        app:cardCornerRadius="20dp">
        <ImageView
            android:id="@+id/iv_vegetable"
            android:layout_width="200dp"
            android:layout_height="200dp"
            android:scaleType="centerCrop"
            android:src="@drawable/orange" />
    </androidx.cardview.widget.CardView>
    <!-- Rounded card with a fixed-height (320dp) scrollable description of
         the recognized item; MainActivity updates tv_vegetable_detail. -->
    <androidx.cardview.widget.CardView
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_gravity="center"
        android:layout_marginTop="10dp"
        app:cardCornerRadius="20dp">
        <!-- NOTE(review): android:layout_margin takes precedence over
             android:layout_marginTop, so the latter is redundant here. -->
        <ScrollView
            android:layout_width="wrap_content"
            android:layout_height="320dp"
            android:layout_margin="10dp"
            android:layout_marginTop="10dp">
            <!-- NOTE(review): the fixed 360dp width may exceed narrow screens. -->
            <TextView
                android:id="@+id/tv_vegetable_detail"
                android:layout_width="360dp"
                android:layout_height="wrap_content"
                android:text="@string/orange"
                android:textColor="@color/black"
                android:textSize="18sp" />
        </ScrollView>
    </androidx.cardview.widget.CardView>
    <!-- Opens the gallery; android:onClick dispatches to
         MainActivity.choose_image(View). -->
    <Button
        android:id="@+id/choose_image"
        android:layout_width="230dp"
        android:layout_height="50dp"
        android:layout_gravity="center"
        android:layout_marginTop="50dp"
        android:background="@drawable/angle_button"
        android:onClick="choose_image"
        android:text="选择图片"
        android:textColor="@android:color/white"
        android:textSize="20sp" />
</LinearLayout>
3.4 编写MainActivity.java代码(以下为核心片段;input、output、class_names、vegetable_detail、getScaledMatrix和TFLiteLoader等定义未列出)
private Interpreter interpreter;          // TFLite interpreter obtained in onCreate
private Bitmap bitmap;                    // image fed to detect_image()
private ImageView iv_vegetable;           // preview of the current image
private TextView tv_vegetable_detail;     // description of the recognized fruit/vegetable
// Runtime permissions this activity wants; onCreate appends
// READ_EXTERNAL_STORAGE on devices below Android 7.0 (N).
private String[] neededPermissions = {
        Manifest.permission.READ_PHONE_STATE
};
/**
 * Tints the status bar gray (Lollipop+), inflates the layout, finalizes the
 * permission list, loads the TFLite model and decodes the placeholder image.
 */
@Override
protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);

    final boolean canColorStatusBar = Build.VERSION.SDK_INT >= Build.VERSION_CODES.LOLLIPOP;
    if (canColorStatusBar) {
        Window w = getWindow();
        int layoutFlags = View.SYSTEM_UI_FLAG_LAYOUT_FULLSCREEN | View.SYSTEM_UI_FLAG_LAYOUT_STABLE;
        w.clearFlags(WindowManager.LayoutParams.FLAG_TRANSLUCENT_STATUS);
        w.getDecorView().setSystemUiVisibility(layoutFlags);
        w.addFlags(WindowManager.LayoutParams.FLAG_DRAWS_SYSTEM_BAR_BACKGROUNDS);
        w.setStatusBarColor(Color.GRAY);
    }

    setContentView(R.layout.activity_main);

    // On Android 7.0+ the picked image's Uri comes through FileProvider and no
    // file permission is needed; older versions also need READ_EXTERNAL_STORAGE.
    // NOTE(review): neededPermissions is never requested in the code shown here.
    if (Build.VERSION.SDK_INT < Build.VERSION_CODES.N) {
        int count = neededPermissions.length;
        String[] withStorage = Arrays.copyOf(neededPermissions, count + 1);
        withStorage[count] = Manifest.permission.READ_EXTERNAL_STORAGE;
        neededPermissions = withStorage;
    }

    initView();
    interpreter = TFLiteLoader.newInstance(this).get();
    showToast("模型加载成功!");
    // Start with the same placeholder the layout shows in iv_vegetable.
    bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.orange);
}
/** Caches the views from activity_main that this activity updates. */
private void initView() {
    iv_vegetable = findViewById(R.id.iv_vegetable);
    tv_vegetable_detail = findViewById(R.id.tv_vegetable_detail);
}
/** Shows {@code text} to the user as a long-duration toast. */
private void showToast(String text) {
    Toast toast = Toast.makeText(this, text, Toast.LENGTH_LONG);
    toast.show();
}
/**
 * Click handler for the "choose image" button (wired via android:onClick in
 * activity_main): opens the gallery picker. The result is delivered to
 * onActivityResult with request code 0.
 */
public void choose_image(View view) {
    Intent pickImage = new Intent(Intent.ACTION_PICK);
    pickImage.setDataAndType(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, "image/*");
    final int pickRequestCode = 0;
    startActivityForResult(pickImage, pickRequestCode);
}
// Index of the highest-scoring class from the latest detect_image() call;
// also used to pick the entry from vegetable_detail.
private int maxIndex = 0;

/**
 * Handles the image picked via choose_image() (request code 0). The image is
 * downscaled to 224x224 (the model's input size in the training script),
 * classified with detect_image(), and then shown together with the description
 * of the predicted class. If no image could be obtained, a failure toast is
 * shown and the previous image and result are left untouched.
 */
@Override
protected void onActivityResult(int requestCode, int resultCode, Intent data) {
    super.onActivityResult(requestCode, resultCode, data);
    if (requestCode != 0) {
        // Not the gallery request started by choose_image().
        return;
    }
    // A cancelled pick, or one without a Uri, counts as a failed fetch.
    if (resultCode != RESULT_OK || data == null || data.getData() == null) {
        showToast("获取图片失败");
        return;
    }
    try {
        Bitmap src = MediaStore.Images.Media.getBitmap(getContentResolver(), data.getData());
        Bitmap scaled = Bitmap.createScaledBitmap(src, 224, 224, false);
        // createScaledBitmap returns src itself when it is already 224x224;
        // otherwise release the (possibly full-resolution) original right away.
        if (scaled != src) {
            src.recycle();
        }
        bitmap = scaled;
    } catch (IOException e) {
        e.printStackTrace();
        // Stop here: continuing would re-classify the stale bitmap and present
        // the old result as if it belonged to the newly picked image.
        showToast("获取图片失败");
        return;
    }
    // Classify, then refresh the preview and the description text.
    detect_image();
    iv_vegetable.setImageBitmap(bitmap);
    tv_vegetable_detail.setText(vegetable_detail[maxIndex]);
}
/**
 * Runs the TFLite model on {@code bitmap}, logs every class score, stores the
 * arg-max class index in {@code maxIndex} and toasts that class's name.
 */
public void detect_image() {
    // Convert the bitmap into the model's input array.
    float[][][][] pixels = getScaledMatrix(bitmap, input);
    interpreter.run(pixels, output);
    float[] scores = output[0];

    // Debug log: each score rounded half-up to 3 decimals, then its index.
    for (int j = 0; j < scores.length; j++) {
        BigDecimal b = new BigDecimal(scores[j]);
        float f1 = b.setScale(3, BigDecimal.ROUND_HALF_UP).floatValue();
        Log.i("Test", f1 + "--> " + j);
    }

    // Arg-max, restarting from index 0 on every call. Previously maxIndex was
    // never reset, so when class 0 scored highest the index left over from the
    // previous image was reported instead. Ties keep the lowest index.
    int best = 0;
    for (int i = 1; i < scores.length; i++) {
        if (scores[i] > scores[best]) {
            best = i;
        }
    }
    maxIndex = best;

    String text = class_names[maxIndex];
    showToast(text);
}
3.5 安装到手机后的识别效果
[图:基于TensorFlow2.3.0的果蔬识别系统的设计——手机端识别效果截图]
四、资料下载
APK下载:https://wwi.lanzoup.com/inXuX0adi9de
完整源码下载:https://item.taobao.com/item.htm?ft=t&id=682478970844