树莓派部署Mobilenet分类网络(二)

在树莓派中创建Thommy文件将下列,将电脑上训练的代码通过VNC传输到树莓派上,只需要把模型和json格式的类别的路径放入相应位置后,就可以实现单片机发送'e',树莓派接收到以后直接发送类别回传给单片机

import os
import json
import torch
import cv2
import serial
from PIL import Image
from torchvision import transforms
from model_v3 import mobilenet_v3_large

def preprocess_image(image):
    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    img = data_transform(img)
    img = torch.unsqueeze(img, dim=0)
    return img

def setup_serial(port):
    ser = serial.Serial(port, 9600, timeout=1)
    return ser

def handle_serial(ser, class_indict, model, device):
    if ser.in_waiting > 0:
        received_data = ser.readline().decode().strip()
        print("Received data:", received_data)

        if received_data == 'e':
            # Capture image from webcam
            cap = cv2.VideoCapture(0)
            ret, frame = cap.read()
            cap.release()

            if not ret:
                print("Failed to capture image")
                return None

            # Preprocess image
            img = preprocess_image(frame)

            with torch.no_grad():
                output = torch.squeeze(model(img.to(device))).cpu()
                predict = torch.softmax(output, dim=0)
                predict_cla = torch.argmax(predict).item()

            print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                         predict[predict_cla].item())
            print(print_res)

            return print_res

    return None

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # read class_indict
    json_path = '/home/pi/Desktop/MobilenetV3/class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = mobilenet_v3_large(num_classes=2).to(device)
    # load model weights
    model_weight_path = '/home/pi/Desktop/MobilenetV3/MobileNetV3.pth'
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()

    # Initialize serial port
    ser = setup_serial('/dev/ttyAMA0')  # Replace with your actual serial port

    try:
        while True:
            # Handle serial communication
            response = handle_serial(ser, class_indict, model, device)
            if response:
                ser.write((response + "\n").encode())
                print("Sent response:", response)

            # Display and process frames from webcam (optional)
            # cv2.imshow('MobileNet Classification', frame)
            # if cv2.waitKey(1) & 0xFF == ord('q'):
            #     break
    finally:
        ser.close()
        # cv2.destroyAllWindows()  # Uncomment if using imshow

if __name__ == '__main__':
    main()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值