有很多没写过代码的小伙伴问我,能不能一步一步地教会他弄一个基于神经网络的图像识别UI,我说你滚吧,我写的代码很烂,也教不会你,后来他再三祈求,我就写了这么一篇垃圾文章,代码我会打上注释,具体自己理解吧。对了,模型是pytorch。
首先呢,确保你电脑安装了python并且to path。创建一个AIGC文件夹,我放在D盘根目录。以后可能各种乱七八糟的东西都往里面丢,然后再弄一个AI torch文件夹,后面的结构我放在这了。
D:\AIGC\AI torch\
│
├── app.py
├── predict.py
├── requirements.txt
└── templates\
└── index.html
第一步,创建文件夹
在 AI torch 文件夹中,创建以下文件:
app.py
predict.py
requirements.txt
然后创建 templates 文件夹,在 templates 文件夹中创建 index.html 文件。
这些文件你就创建一个文本文档改后缀就行,修改内容的时候选择用文本文档打开就行。
第二步,搞代码
我们将这段代码丢进predict.py。
# predict.py
import torch
from torchvision import models, transforms
from PIL import Image
import json
import urllib.request
# 下载并加载 ResNet50 预训练模型
model = models.resnet50(pretrained=True)
model.eval() # 设为评估模式
# 定义图像预处理
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# ImageNet 标签(类名)下载 URL
imagenet_labels_url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
# 加载 ImageNet 类标签
with urllib.request.urlopen(imagenet_labels_url) as url:
imagenet_labels = json.loads(url.read().decode())
def predict(image_path):
# 打开图片并应用预处理
input_image = Image.open(image_path).convert("RGB") # 确保图像是 RGB 模式
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # 创建 batch 尺寸
# 确保模型在没有 GPU 的情况下工作
if torch.cuda.is_available():
input_batch = input_batch.to('cuda')
model.to('cuda')
with torch.no_grad():
output = model(input_batch)
# 获取预测的标签
_, predicted_idx = torch.max(output, 1)
predicted_label = imagenet_labels[predicted_idx.item()]
return predicted_label
这段丢进app.py。
# app.py
from flask import Flask, request, jsonify, render_template
import os
from werkzeug.utils import secure_filename
from predict import predict
app = Flask(__name__)
# 设置上传文件夹
UPLOAD_FOLDER = 'uploads'
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
@app.route('/')
def index():
return render_template('index.html')
@app.route('/predict', methods=['POST'])
def upload_file():
if 'file' not in request.files:
return jsonify({'error': 'No file part'})
file = request.files['file']
if file.filename == '':
return jsonify({'error': 'No selected file'})
if file:
filename = secure_filename(file.filename)
filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
file.save(filepath)
label = predict(filepath)
return jsonify({'prediction': label})
if __name__ == '__main__':
app.run(debug=True)
这段丢进index.html。
<!-- index.html -->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>物体识别</title>
</head>
<body>
<h1>上传图片进行物体识别</h1>
<form action="/predict" method="post" enctype="multipart/form-data">
<input type="file" name="file" accept="image/*" required>
<input type="submit" value="上传">
</form>
</body>
</html>
这段是requirements.txt,这个文件是集成了要安装的东西,方便搭环境。
# requirements.txt
torch
torchvision
flask
Pillow
第三步,开干
win+R键,打开命令提示符。
输入以下内容然后回车,这是切换目录。
cd /d "D:\AIGC\AI torch"
然后输入,创建虚拟环境。
python -m venv venv
接着激活环境,这下就进入venv中了。
venv\Scripts\activate
第四步,即将
安装依赖,这就用到了之前的txt文件。
pip install -r requirements.txt
然后我们运行app.py
python app.py
当出现
* Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
的时候,大功造成。
这时,在浏览器打开上面这个链接,就可以实现图像识别了
但是呢,垃圾系列,效果着实很垃圾。
笑不活了。
整活文章,别太当回事。