# 在树莓派上用 Thonny 新建一个文件，粘贴下列代码。将在电脑上训练好的模型权重文件和 JSON 格式的类别文件通过 VNC 传输到树莓派上，再把它们的路径填入代码中的相应位置，即可实现：单片机发送 'e'，树莓派收到后拍照识别，并直接把识别出的类别回传给单片机。
import os
import json
import torch
import cv2
import serial
from PIL import Image
from torchvision import transforms
from model_v3 import mobilenet_v3_large
def preprocess_image(image):
    """Turn an OpenCV BGR frame into a normalized single-image batch tensor.

    Steps: BGR -> RGB conversion, wrap as a PIL image, resize the shorter
    side to 256, center-crop to 224x224, convert to a float tensor, and
    normalize with the standard ImageNet per-channel mean/std.

    Args:
        image: a BGR frame, e.g. as returned by ``cv2.VideoCapture.read``.

    Returns:
        A tensor of shape (1, 3, 224, 224).
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    rgb_frame = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    tensor = pipeline(Image.fromarray(rgb_frame))
    # Add the leading batch dimension expected by the model.
    return tensor.unsqueeze(0)
def setup_serial(port, baudrate=9600, timeout=1):
    """Open the serial port used to talk to the microcontroller.

    Args:
        port: serial device path, e.g. ``'/dev/ttyAMA0'``.
        baudrate: line speed; must match the microcontroller's setting
            (default 9600, the value this script always used).
        timeout: read timeout in seconds applied to ``read``/``readline``
            (default 1).

    Returns:
        An open ``serial.Serial`` instance.
    """
    return serial.Serial(port, baudrate, timeout=timeout)
def handle_serial(ser, class_indict, model, device):
    """Poll the serial port once and classify a webcam frame on an 'e' command.

    Returns immediately when ``ser.in_waiting`` is 0. Otherwise one line is
    read; if it equals 'e' after stripping whitespace, a single frame is
    grabbed from camera 0 and classified.

    Args:
        ser: open serial port exposing ``in_waiting`` and ``readline()``.
        class_indict: mapping from the class index as a string (``"0"``,
            ``"1"``, ...) to the class name.
        model: classifier already on ``device`` and in eval mode.
        device: torch device the model lives on.

    Returns:
        ``"class: <name> prob: <p>"`` after a successful classification;
        None when no data is waiting, the command is not 'e', or the camera
        read fails.
    """
    if ser.in_waiting <= 0:
        return None
    # Serial lines can carry garbage bytes (e.g. while the MCU boots). A strict
    # UTF-8 decode would raise UnicodeDecodeError and kill the main loop, so
    # undecodable bytes are replaced instead.
    received_data = ser.readline().decode("utf-8", errors="replace").strip()
    print("Received data:", received_data)
    if received_data != 'e':
        return None

    frame = _capture_frame()
    if frame is None:
        print("Failed to capture image")
        return None

    print_res = _classify(frame, class_indict, model, device)
    print(print_res)
    return print_res


def _capture_frame(camera_index=0):
    """Grab one frame from the webcam; always release the device. None on failure."""
    cap = cv2.VideoCapture(camera_index)
    try:
        ret, frame = cap.read()
    finally:
        # Release even if read() raises, so the camera is not left locked.
        cap.release()
    return frame if ret else None


def _classify(frame, class_indict, model, device):
    """Run the model on one BGR frame and format the top-1 class and probability."""
    img = preprocess_image(frame)
    with torch.no_grad():
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).item()
    return "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
                                           predict[predict_cla].item())
def main(json_path='/home/pi/Desktop/MobilenetV3/class_indices.json',
         model_weight_path='/home/pi/Desktop/MobilenetV3/MobileNetV3.pth',
         port='/dev/ttyAMA0',
         poll_interval=0.01):
    """Load the classifier and answer classification requests over serial.

    Loads the class-index JSON and model weights, opens the serial port, then
    loops forever. Each classification result from ``handle_serial`` is
    written back followed by a newline. The port is closed on exit,
    including on KeyboardInterrupt.

    Args:
        json_path: JSON file mapping class index strings to class names.
        model_weight_path: state_dict saved from training.
        port: serial device connected to the microcontroller (replace with
            your actual serial port).
        poll_interval: seconds to sleep when there is nothing to answer.

    Raises:
        FileNotFoundError: if ``json_path`` does not exist.
    """
    import time  # local import: only the idle back-off in the loop needs it

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Read class_indict. Raise explicitly: an assert is stripped under `python -O`.
    if not os.path.exists(json_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(json_path))
    # Explicit UTF-8 so non-ASCII class names load regardless of the system locale.
    with open(json_path, "r", encoding="utf-8") as f:
        class_indict = json.load(f)

    # Size the classifier head from the class file instead of hard-coding 2,
    # so the model always agrees with the labels it reports.
    model = mobilenet_v3_large(num_classes=len(class_indict)).to(device)
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()

    ser = setup_serial(port)
    try:
        while True:
            response = handle_serial(ser, class_indict, model, device)
            if response:
                ser.write((response + "\n").encode())
                print("Sent response:", response)
            else:
                # handle_serial returns immediately when no bytes are waiting;
                # without a pause this loop would spin a CPU core at 100%.
                time.sleep(poll_interval)
    finally:
        ser.close()
# Start the serial classification service only when run as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()