网页游戏架设_如何从头架设一个Google Quickdraw Demo

d678e4a0e8451574e2620a9138762c9b.png

最近在做ITP的毕业作品,我的设想是做一个用涂鸦+AI识别的方式来玩的游戏。类似Scribblenauts, 但是输入词语的方式是用涂鸦和AI识别。

Google并没有开放涂鸦识别的API,只开放了数据集,所以我需要自己做一个服务器来提供涂鸦识别服务。

这周把服务器架设好了,在这里分享一下我的实现方法,也方便后面的人参考。

完成的代码在这里:

EonYang/flaskServer​github.com
f7730ec3094c337f84021ae466ce1d67.png

我的Demo在这里:

p5.js Web Editor​editor.p5js.org

我并不是Machine Learning的专家,Python也是今年才开始学。其中参考,借用,copy了很多别人的代码和模型。

第零步,设想:

  1. 我需要一个模块,运行在云上,把所有输入的图片变成词语。比如我画了一个小鸟,图片进去,“bird”出来。所以我需要一个Tensorflow的model。
  2. 我需要把它做成一个API,我可以在任何地方发图片过去,并且拿到识别结果。不需要很强大,用Flask就行了。
  3. 随便在什么地方,例如网页,桌面应用,Unity游戏里,请求上面完成的API,并获得结果。

第一步:Tensorflow和model。

值得庆幸的是,Kaggle组织了一个QuickDraw训练的比赛,很多人参与并公开了自己的实现方式和模型。我直接借用了排名第一的代码和模型。原文和链接在这里。

通过读他的代码,可以看出他做了什么:

  1. 加载Keras提供的Pre-trained Mobilenet模型。
  2. 读数据集,数据集里面都是笔划,用opencv把笔划画到图片上。
  3. 用画出的图片去train Keras的Mobilenet,得到新的Weight。
  4. 然后写了一些代码去测试他的model的表现。

我电脑烂,偷懒不想自己Train模型。直接下载他的model.h5文件。

同时我们还需要借用的是他用来Prepare图像的function。需要注意的是,这个Model必须输入黑色背景白色笔划的图片。

环境: Python 3.5.5, Tensorflow和Keras当前最新的不知道什么版本。

1. 现在开始做了。先建立一个predictor.py, 导入用得上的library,一大堆。

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import os
import json
import datetime as dt
import cv2
import base64
import io
from PIL import Image
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.layers import Dense, Dropout, Flatten, Activation
from tensorflow.keras.metrics import categorical_accuracy, top_k_categorical_accuracy, categorical_crossentropy
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.applications.mobilenet import preprocess_input
from tensorflow.keras import backend as K

这里面很多是测试时候用过的libray,后面好像有些测试代码删了,所有有些library并没有用上,但是我懒得整理了。

核心library有几个:

  • numpy和panda,用来处理原始的数据集。不过用到panda的代码好像已经被我删了。
  • cv2,图片处理必备
  • tensorflow和keras,这个project的核心library。

2. 定义一些后面需要的Variables,暂时不用搞懂他们是干什么的。

BASE_SIZE = 256
NCSVS = 100
NCATS = 340
np.random.seed(seed=1987)
tf.set_random_seed(seed=1987)

def top_3_accuracy(y_true, y_pred):
    return top_k_categorical_accuracy(y_true, y_pred, k=3)

STEPS = 800
EPOCHS = 16
size = 64
batchsize = 680

3. Load and compile我们的模型和weight,同时创建我们的Session和Graph。

把前面下载到的model.h5放到model文件夹下面。 真正的model是从Keras导入的Mobilenet,这个model.h5只是一个weight文件,要用导入Mobilenet后,用load_weights来加载。

def init():
    sess = tf.InteractiveSession()
    loaded_model = MobileNet(input_shape=(size, size, 1), alpha=1., weights=None, classes=NCATS)
    loaded_model.load_weights("./model/model.h5")
    loaded_model.compile(optimizer=Adam(lr=0.002), loss='categorical_crossentropy',
                  metrics=[categorical_crossentropy, categorical_accuracy, top_3_accuracy])
    print(loaded_model.summary())
    graph = tf.get_default_graph()
    return loaded_model, sess, graph

global model, sess, graph
model, sess, graph = init()

要不要测试一下?测试代码已经被我删了,直接往下写吧。

4. 这个模型只能输入64*64的,单色的,黑底白笔划的图片,所以我们写个function,用opencv处理一下收到的图片先。

def prepareImage(im):
    #gray
    im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    #binary
    thresh = 127
    im = cv2.threshold(im, thresh, 255, cv2.THRESH_BINARY)[1]
    # see if need to invert
    n_white_pix = np.sum(im == 255)
    n_black_pix = np.sum(im == 0)
    if n_white_pix > n_black_pix:
        im = cv2.bitwise_not(im)
    #trim1, move content to the left-up corner;
    size = len(im[0])
    sum0 = im.sum(axis = 0)
    sum1 = im.sum(axis = 1)
    for i in range(len(sum0)):
        if sum0[i] == 0:
            im = np.delete(im, 0, 1)
            zero = np.zeros((size,1))
            im = np.append(im,zero,1)
        else :
            break
    for i in range(len(sum1)):
        if sum1[i] == 0:
            im = np.delete(im, 0, 0)
            zero = np.zeros((1,size))
            im = np.append(im,zero,0)
        else :
            break
    # trim2 crop content
    sum3 = im.sum(axis = 0)
    sum4 = im.sum(axis = 1)
    x2 = 1
    y2 = 1
    while x2 < len(sum3) and sum3[-x2] ==0:
        x2 += 1
    while y2 < len(sum4) and sum4[-y2] ==0:
        y2 += 1
    w = size - x2
    h = size - y2
    contentSize = w if w > h  else  h
    # only crop if there is realy content
    if contentSize > 16:
        im = im[0:contentSize, 0:contentSize]
    return im

这个function做的事情是:

  • 不管什么图片,先变成黑白灰的,然后再处理成黑白单色(只有纯黑和纯白)的。
  • 看看有多少白色的像素,和多少黑色的像素。哪个多就说明哪个是背景色。
  • 如果背景色是白色,反相它。如果背景色是黑色,什么也不做。
  • 把图片内容移到左上角。
  • 把右边和底部多余的黑色背景切掉,留下一个正方形。之所以这样做,是因为这个模型是用Google提供的Simplified数据训练的。这个数据里面都是多余的背景切掉的正方形。如果我们输入的数据也这样,结果会更准确。
  • 这里暂时没有把图片裁剪成64,因为我想先把原始分辨率的图片储存一下,收集数据,后面再裁剪。

5.写个function来输入图片和返回结果

def prepareImageAndPredict(model, cv2ImageData,size=64):
    try:
        # downsize to 64
        image64 = cv2.resize(cv2ImageData, (64, 64))
        x = np.zeros((1,size, size, 1))
        x[0, :, :, 0] = image64
        x = preprocess_input(x).astype(np.float32)
        prediction = model.predict(x, batch_size=128, verbose=1)
        top5 = np.argsort(-prediction, axis=1)[:, :5]
        return top5[0]

    except Exception as e:
        print(e)
        pass

这里面做的事情是:

  • Resize到64*64.
  • 虽然我们只有一张图片,还是要创建一个4d的array,这样模型才能处理我们的data。只是其中一个维只有一个数据而已。
  • Call model.predict得到Prediction
  • 得到最高confidence的5个结果。
  • 由于我们只有一张图片,所以return top5[0]

6. 现在测试一下模型好了。

imagePath = "./whateverDoodle.jpg"
image = cv2.imread(imagePath)
image = prepareImage(image)

with sess.as_default():
        with graph.as_default():
            prediction= prepareImageAndPredict(model, image).tolist()

print(prediction)

以前的测试代码已经删了,这一段是我临时敲的,没运行过,可能会有错误什么的。

如果成功,会打印出5个数字,0-340之间,每个数字代表一个物品。

数字是以下list的index:

categories = ['airplane', 'alarm clock', 'ambulance', 'angel', 'animal migration', 'ant', 'anvil', 'apple', 'arm', 'asparagus', 'axe', 'backpack', 'banana', 'bandage', 'barn', 'baseball', 'baseball bat', 'basket', 'basketball', 'bat', 'bathtub', 'beach', 'bear', 'beard', 'bed', 'bee', 'belt', 'bench', 'bicycle', 'binoculars', 'bird', 'birthday cake', 'blackberry', 'blueberry', 'book', 'boomerang', 'bottlecap', 'bowtie', 'bracelet', 'brain', 'bread', 'bridge', 'broccoli', 'broom', 'bucket', 'bulldozer', 'bus', 'bush', 'butterfly', 'cactus', 'cake', 'calculator', 'calendar', 'camel', 'camera', 'camouflage', 'campfire', 'candle', 'cannon', 'canoe', 'car', 'carrot', 'castle', 'cat', 'ceiling fan', 'cell phone', 'cello', 'chair', 'chandelier', 'church', 'circle', 'clarinet', 'clock', 'cloud', 'coffee cup', 'compass', 'computer', 'cookie', 'cooler', 'couch', 'cow', 'crab', 'crayon', 'crocodile', 'crown', 'cruise ship', 'cup', 'diamond', 'dishwasher', 'diving board', 'dog', 'dolphin', 'donut', 'door', 'dragon', 'dresser', 'drill', 'drums', 'duck', 'dumbbell', 'ear', 'elbow', 'elephant', 'envelope', 'eraser', 'eye', 'eyeglasses', 'face', 'fan', 'feather', 'fence', 'finger', 'fire hydrant', 'fireplace', 'firetruck', 'fish', 'flamingo', 'flashlight', 'flip flops', 'floor lamp', 'flower', 'flying saucer', 'foot', 'fork', 'frog', 'frying pan', 'garden', 'garden hose', 'giraffe', 'goatee', 'golf club', 'grapes', 'grass', 'guitar', 'hamburger', 'hammer', 'hand', 'harp', 'hat', 'headphones', 'hedgehog', 'helicopter', 'helmet', 'hexagon', 'hockey puck', 'hockey stick', 'horse', 'hospital', 'hot air balloon', 'hot dog', 'hot tub', 'hourglass', 'house', 'house plant', 'hurricane', 'ice cream', 'jacket', 'jail', 'kangaroo', 'key', 'keyboard', 'knee', 'ladder', 'lantern', 'laptop', 'leaf', 'leg', 'light bulb', 'lighthouse', 'lightning', 'line', 'lion', 'lipstick', 'lobster', 'lollipop', 'mailbox', 'map', 'marker', 'matches', 'megaphone', 'mermaid', 'microphone', 'microwave', 'monkey', 'moon', 'mosquito', 'motorbike', 'mountain', 'mouse', 'moustache', 'mouth', 'mug', 'mushroom', 'nail', 'necklace', 'nose', 'ocean', 'octagon', 'octopus', 'onion', 'oven', 'owl', 'paint can', 'paintbrush', 'palm tree', 'panda', 'pants', 'paper clip', 'parachute', 'parrot', 'passport', 'peanut', 'pear', 'peas', 'pencil', 'penguin', 'piano', 'pickup truck', 'picture frame', 'pig', 'pillow', 'pineapple', 'pizza', 'pliers', 'police car', 'pond', 'pool', 'popsicle', 'postcard', 'potato', 'power outlet', 'purse', 'rabbit', 'raccoon', 'radio', 'rain', 'rainbow', 'rake', 'remote control', 'rhinoceros', 'river', 'roller coaster', 'rollerskates', 'sailboat', 'sandwich', 'saw', 'saxophone', 'school bus', 'scissors', 'scorpion', 'screwdriver', 'sea turtle', 'see saw', 'shark', 'sheep', 'shoe', 'shorts', 'shovel', 'sink', 'skateboard', 'skull', 'skyscraper', 'sleeping bag', 'smiley face', 'snail', 'snake', 'snorkel', 'snowflake', 'snowman', 'soccer ball', 'sock', 'speedboat', 'spider', 'spoon', 'spreadsheet', 'square', 'squiggle', 'squirrel', 'stairs', 'star', 'steak', 'stereo', 'stethoscope', 'stitches', 'stop sign', 'stove', 'strawberry', 'streetlight', 'string bean', 'submarine', 'suitcase', 'sun', 'swan', 'sweater', 'swing set', 'sword', 't-shirt', 'table', 'teapot', 'teddy-bear', 'telephone', 'television', 'tennis racquet', 'tent', 'The Eiffel Tower', 'The Great Wall of China', 'The Mona Lisa', 'tiger', 'toaster', 'toe', 'toilet', 'tooth', 'toothbrush', 'toothpaste', 'tornado', 'tractor', 'traffic light', 'train', 'tree', 'triangle', 'trombone', 'truck', 'trumpet', 'umbrella', 'underwear', 'van', 'vase', 'violin', 'washing machine', 'watermelon', 'waterslide', 'whale', 'wheel', 'windmill', 'wine bottle', 'wine glass', 'wristwatch', 'yoga', 'zebra', 'zigzag']

现在如果它work了,我们可以创建一个Flask server了

第二步,建立Flask server:

1. 建立一个server.py,然后按照python惯例,import一大堆library,即使有些已经用不着了。

from flask import Flask , jsonify, request, render_template, send_from_directory
from flask_cors import CORS
from predictor import *
import random
import json
import pandas as pd
import numpy as np
from tensorflow.keras import models
import time
import datetime
import cv2
import sys, getopt
import os
import base64
import io
from PIL import Image

2. 创建server。我的Demo由于偷懒,我必须用Https。但是我假设你不需要https,所以我把https和parse arguments的部分都删了。

app = Flask(__name__)
CORS(app)

CORS可以允许你从别的域名来Ajax请求你的API。也是我偷懒用的,因为我的客户端demo没有放在自己的服务器上,直接托管在了p5js上。如果你的客户端和自己API同一个域名,你不需要这个。

3. 创建第一个route。

def stringToRGB(base64_string):
    imgdata = base64.b64decode(str(base64_string))
    image = Image.open(io.BytesIO(imgdata))
    return cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)

@app.route("/api/doodlePredict", methods=["POST"])
def predictAPI():
    global model, graph
    print("this is the request: ", request.form.to_dict())
    image_raw = request.form.to_dict()["data"]
    image_raw = stringToRGB(image_raw)
    image = prepareImage(image_raw)
    response = {'prediction':{
    'numbers':[],
    'names':[]
    }}
    with sess.as_default():
        with graph.as_default():
            response['prediction']['numbers'] = prepareImageAndPredict(model, image).tolist()
    for i in range(len(response['prediction']['numbers'])):
        response['prediction']['names'].append(categories[response['prediction']['numbers'][i]])
    print("this is the response: ", response['prediction']['names'])
    cv2.imwrite("./doodleHistory/"+ ', '.join(response['prediction']['names']) +", "+datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y") +".jpg", image_raw)
    return jsonify(response)

这个route假设,图片文件是被转换成了base64 之后,放在一个form里面的data栏 里,通过POST传过来的。

于是先读form,然后转换成dict,然后从dict里面读data,读完了再转成rgb图片。

然后call我们的prepareImage,确保图片是黑底白笔划,没有多余背景。

然后就和上面测试一样,Sess和Graph用上,callmodel.predict()

我顺便把传过来的Doodle图像给存起来了。

然后把得到的结果弄成一个JSON,return回去。

4. 最后,app.run()

if __name__ == "__main__":
    app.run(host = "0.0.0.0", port = 5800, debug = True)

注意必须定义host = "0.0.0.0"才可以从外网访问这个API。

第三步,写个客户端来用我们的API

我直接在p5js的在线editor里面写了,因为这样比较快。这里不按模块一段一段写了,只写写思路,和放出代码。

p5.js Web Editor​editor.p5js.org
  • 创建一个p5的画布。
  • 创建一个2d array来存所有的stroke,stroke本身是一个1d的array。
  • 从第0条stroke开始。每次按下鼠标,把当前鼠标位置存进当前stroke。
  • 每次松开鼠标,创建新的空stroke。
  • 每秒60帧的把所有的stroke画在画布上。
  • 松开鼠标的时候,代表完成了一条新的stroke。这个时候把canvas图片转成base64的string,放在一个form里,用JQuery把数据POST到我们前面做好的API。
  • 收到返回的JSON后,把结果显示在网页上。

今天先写这么多,回头补图。

游戏正在制作中,核心模块已经快完成,有空的话我会写一篇来分享我的游戏是怎么做的而且怎么用ML的。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值