当tensorflow遇见flappy

当tensorflow遇见flappy

标签: tensorflow python


本文主要目标是用tensorflow实现识别拳头和手掌,然后融合到flappy中,进行控制小鸟的行动。

拳头手掌分类器模型构建

数据集制作

这里由于想学习一个完整的流程,所以数据集也自己采集了,并且这个项目较小,所以采集起来也不是特别麻烦,我就是直接从摄像头捕捉到手部图片,然后做出手掌和拳头的操作,保存到硬盘的两个文件夹中用来区分。代码如下:

"""
Created on 2019/2/9 18:00
@File: collect_data.py
@author: coderwangson
"""
"#codeing=utf-8"
import cv2 as cv

capture = cv.VideoCapture(0)
count = 100
for i in range(count):
    _,img = capture.read()
    img = cv.flip(img,1)
    img = img[50:350,426:640,:]
    cv.imwrite("./bu/"+"1"+str(i)+".jpg",img)
    cv.imshow("1",img)
    print(i)
    cv.waitKey(24)

这里收集了100张图片,并且写入到/bu/这个文件夹下面,说明里面存的都是手掌(布),同样的你可以修改为/shitou/这样你做出石头的动作,就可以进行石头(拳头)数据的收集。

数据集处理

其实正规处理要分为很多,你可以进行数据清洗,数据增强等,这里我为了便于训练就把图片转为灰度并且缩小为64*64的尺寸。之后我们可以把图片的数据信息存储起来,这样可以在后续训练的时候直接从一个文件读取即可,这里为了省事,我就直接用pkl文件存储下来整个数据集信息。

"""
Created on 2019/2/9 18:06
@File: process_data.py
@author: coderwangson
"""
"#codeing=utf-8"
import os
import cv2 as cv
import numpy as np
import pickle
dirname = "./shitou/"
a = []
for f in os.listdir(dirname):
    img = cv.imread(dirname+f)
    img = cv.resize(img,(64,64))
    img = cv.cvtColor(img,cv.COLOR_BGR2GRAY)
    img = np.reshape(img,(64,64,1))
    a.append(img)
a = np.array(a)
with open("shitou.pkl","wb") as f:
    pickle.dump(a,f)

这里就做了一个简单操作,从文件夹下读取所有图片,然后把图片进行缩放处理。最后我们把像素信息存储到一个ndarray中,我们把这个ndarray存储到一个文件shitou.pkl中,这样我们以后使用可以直接从这个文件里拿到那个矩阵了。

除此之外我还写了两个方法,一个是get_data,一个是get_batch,这两个方法是在训练模型的时候,我们可以直接使用get_batch用来得到一个batch_size的数据。

def get_data():
    shitou = None
    shitou_label = None
    bu = None
    bu_label = None
    with open("shitou.pkl","rb") as f:
        shitou = pickle.load(f)
        shitou_label = np.zeros((shitou.shape[0],1))
    with open("bu.pkl", "rb") as f:
        bu = pickle.load(f)
        bu_label = np.ones((bu.shape[0],1))
    data = np.vstack([shitou,bu])
    label = np.vstack([shitou_label,bu_label])
    label = label.astype(int)
    label = np.eye(2)[label].reshape(data.shape[0],2)
    return data,label  
    
def get_batch(batch_size):
    data,label = get_data()
    p = np.random.permutation(data.shape[0])[0:batch_size]
    return data[p],label[p]

在get_data中主要是从pkl中拿出数据,然后把拳头和手掌的数据放在一起,并且为他们生成一个标签,注意标签使用的是one-hot格式,label = np.eye(2)[label].reshape(data.shape[0],2)这句话就是把普通的转为one-hot格式2代表有几个类别。而对于get_batch主要就是得到一个batch的数据,其实就是随机从整个数据集中拿size个数据,这里使用np.random.permutation(data.shape[0])[0:batch_size]生成了data.shape[0]个随机数,然后[0:batch_size]取出了batch_size个即可。

训练模型

上面的前期工作做完后,我们数据集就有了,然后就可以使用tensorflow进行训练然后得到模型了,我们使用CNN构建网络模型,这个大家应该很熟悉,如果不熟悉的话可以参考一步一步实现CNN卷积神经网络使用numpy并对mnist预测 这篇或者自己去网上找一篇怎么构建一个CNN网络类的文章看一下。

我们搭建好了网络结构并且损失函数都定义完成就可以进行训练了,这里要记得保存模型为了以后预测使用,具体代码可以去我的github上看,模型代码对应Model.py tensorflow_flappy

模型的使用

经过上面的训练我们得到了一个模型,我们以后就可以使用这个模型进行预测了,给进去一张图片,然后就能得出是拳头还是手掌。

"""
Created on 2019/2/9 19:16
@File:predict.py
@author: coderwangson
"""
"#codeing=utf-8"
import tensorflow as tf
import numpy as np
import cv2 as cv
import time
def predict(img,sess,x_image,y_,keep_drop):
    y = sess.run(y_, feed_dict={x_image:img, keep_drop: 1.0})
    return np.argmax(y)
# 因为sess加载耗时间,所以在程序中使用,这样我们就能避免每次加载耗时间
start = time.clock()
with tf.Session() as sess:
    saver = tf.train.import_meta_graph("./model/hand.meta")
    saver.restore(sess,tf.train.latest_checkpoint("./model"))
    graph = tf.get_default_graph()
    x_image = graph.get_tensor_by_name("Placeholder:0")
    y_ = graph.get_tensor_by_name("add_3:0")
    keep_drop = graph.get_tensor_by_name("Placeholder_2:0")
    # print(time.clock() - start)
    capture = cv.VideoCapture(0)
    while True:
        ha, img = capture.read()
        img = cv.flip(img,1)
        cv.rectangle(img, (426, 50), (640, 350), (170, 170, 0))
        img = img[50:350, 426:640, :]
        img = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
        cv.imshow("img", img)
        img = cv.resize(img, (64, 64)).reshape((1, 64, 64, 1))

        # 0是拳头 1 是手掌
        action = predict(img, sess,x_image,y_,keep_drop)
        if action ==1:
            print("张开")
        elif action==0:
            print("合住")
        cv.waitKey(24)

这里主要就是那个sess获取,还要图的获取是从模型里面得到的,并且通过测试发现这段代码很耗时间,所以我们获取sess就在游戏里面在刚开始的时候就获取一次即可,避免每次调用predict就要加载一下sess。注意的是我们的图片也要进行resize的处理,和我们制作数据集的时候处理类似。

融合到flappy中

flappy游戏

这个我使用的是别人的代码,可以去我上面git上自己去取,我么主要改动的地方在mainGame while(True)这个地方,因为这个地方是游戏的开始地方,所以我们只需要把原来的代码进行改动即可,我们游戏原来是进行捕捉你键盘按键的响应,现在我们则修改成根据你的手势然后自动用win32按下按键,这样你就可以减小代码改动。

融合到游戏里面

主要是要开启一个摄像头,然后我们在游戏的while(True)里面持续捕获所有帧,然后传入到predict中进行预测,如果是拳头,则模拟按下按键,否则则释放按键。

    while True:
        ha, img = capture.read()
        img = cv2.flip(img,1)
        cv2.rectangle(img, (426, 50), (640, 350), (170, 170, 0))
        img = img[50:350, 426:640, :]
        cv2.imshow("img", img)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img = cv2.resize(img, (64, 64)).reshape((1, 64, 64, 1))

        # 0是拳头 1 是手掌
        action = predict(img, sess,x_image,y_,keep_drop)
        if action ==1:
            win32api.keybd_event(38,0,win32con.KEYEVENTF_KEYUP,0) #释放按键
        else:
            win32api.keybd_event(38, 0, 0, 0)
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值