这个Project的目的是利用神经卷积网络(CNN)来分类(classify)常见的交通标志。CNN 在电脑读图领域已经全面超过了传统的机器学习电脑读图的方法(SVC, OpenCV)。大量的数据是深度学习准确性的保证, 在数据不够的情况下也可以人为的对原有数据进行小改动从而来提高识别的准确度。
- 导入必要的软件包(pickle, numpy, cv2, matplotlib, sklearn, tensorflow, Keras)
# Load pickled data
import pickle
import cv2
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.contrib.layers import flatten
### Load the images and plot them here.
import os
# Visualizations will be shown in the notebook.
%matplotlib inline
- 数据来源:
和大部分的机器学习的要求一样, CNN需要大量有label的数据,German Traffic Sign Dataset提供了对于这个project的研究并给出结果可用于比较,数据在这里可以下载到。解压后就可以用Python导入了:
training_file = '/Volumes/SSD/traffic-signs-data/train.p' # change it to your local dir
testing_file = '/Volumes/SSD/traffic-signs-data/test.p' # change it to your local dir
with open(training_file, mode='rb') as f:
train = pickle.load(f)
with open(testing_file, mode='rb') as f:
test = pickle.load(f)
X_train, y_train = train['features'], train['labels']
X_test, y_test = test['features'], test['labels']
# Number of training examples
n_train = X_train.shape[0]
# Number of testing examples.
n_test = X_test.shape[0]
# What's the shape of an traffic sign image?
image_shape = X_train[0].shape
# How many unique classes/labels there are in the dataset?
n_classes = len(set(y_train))
print("Number of training examples =", n_train)
print("Number of testing examples =", n_test)
print("Image data shape =", image_shape)
print("Number of classes =", n_classes)
输出:
Number of training examples = 39209
Number of testing examples = 12630
Image data shape = (32, 32, 3)
Number of classes = 43
从上面我们可以看到有39209个用作训练的图像 和 12630个testing data。39209张照片对于训练CNN来说是不够的(100000张以上是比较理想的数据量), 所以之后要加入data augment 的模块来人为增加数据。 每张图像的大小是是32x32 并且有3个信道。总共有43个不同的label。
我们也可以把每个label对应的图片随机选择一张画出来。
# show a random sample from each class of the traffic sign dataset
rows, cols = 4, 12
fig, ax_array = plt.subplots(rows, cols) # ax_array is a array object consistint of plt object
plt.suptitle('RANDOM SAMPLES FROM TRAINING SET (one for each class)')
for class_idx, ax in enumer