细致入微的理解ROS中的入门级别之手写数字识别在ROS领域的研究

一种东西只要懂得他是如何工作的,原理是什么,那么与之类似的东西我们都可以一通百通,就比如手写数字识别被称为进入机器学习的hello world,那么我想如果我们要想学习其他的项目的话,我们只要深刻理解了其基本内涵,我想对于机器学习之路就会得心应手的。
下面来看一个完整的在ROS里边运用Tensorflow来实现手写数字的识别:

#!/usr/bin/env python 
# -*- coding: utf-8 -*-
 
import rospy
from sensor_msgs.msg import Image
from std_msgs.msg import Int16
from cv_bridge import CvBridge
import cv2
import numpy as np
import input_data  
import tensorflow as tf

class MNIST():
    def __init__(self):
        image_topic = rospy.get_param("~image_topic", "")

        self._cv_bridge = CvBridge()

        #MNIST数据输入  
        self.mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)  
          
        self.x = tf.placeholder(tf.float32,[None, 784]) #图像输入向量  
        self.W = tf.Variable(tf.zeros([784,10]))        #权重,初始化值为全零  
        self.b = tf.Variable(tf.zeros([10]))            #偏置,初始化值为全零  
          
        #进行模型计算,y是预测,y_ 是实际  
        self.y = tf.nn.softmax(tf.matmul(self.x, self.W) + self.b)  
          
        self.y_ = tf.placeholder("float", [None,10])  
          
        #计算交叉熵  
        self.cross_entropy = -tf.reduce_sum( self.y_*tf.log(self.y))  
        #接下来使用BP算法来进行微调,以0.01的学习速率  
        self.train_step = tf.train.GradientDescentOptimizer(0.01).minimize(self.cross_entropy)  
          
        #上面设置好了模型,添加初始化创建变量的操作  
        self.init = tf.global_variables_initializer()  
        #启动创建的模型,并初始化变量  
        self.sess = tf.Session()  
        self.sess.run(self.init)  

        #开始训练模型,循环训练1000次  
        for i in range(1000):  
            #随机抓取训练数据中的100个批处理数据点  
            batch_xs, batch_ys = self.mnist.train.next_batch(100)  
            self.sess.run(self.train_step, feed_dict={self.x:batch_xs, self.y_:batch_ys})  

        ''''' 进行模型评估 '''  
        #判断预测标签和实际标签是否匹配  
        correct_prediction = tf.equal(tf.argmax(self.y,1),tf.argmax(self.y_,1))   
        self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))  
       
        #计算所学习到的模型在测试数据集上面的正确率  
        print( "The predict accuracy with test data set: \n")
        print( self.sess.run(self.accuracy, feed_dict={self.x:self.mnist.test.images, self.y_:self.mnist.test.labels}) )  

        self._sub = rospy.Subscriber(image_topic, Image, self.callback, queue_size=1)
        self._pub = rospy.Publisher('result', Int16, queue_size=1)

    def callback(self, image_msg):
        #预处理接收到的图像数据
        cv_image = self._cv_bridge.imgmsg_to_cv2(image_msg, "bgr8"
  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 10
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

暗香独自开

你的鼓励是我总结的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值