简介
Term-1第二节课是进行交通标志分类,数据集主要来自于German Traffic Sign,包含了42种交通标志,通过深度学习网络进行分类。
环境准备
- python 2.7
- numpy
- scikit-learn
- tensorflow
- keras
处理流程
处理流程如下图所示
数据读取
我们拿到的数据集是一系列交通标志图像,每个类别的交通标志放在了同一个文件夹下,并且有一个csv文件用于描述每个交通标志图片的ROI区域和该标志所属类别。下载图像的时候网站提供了一份用于数据处理的python程序(Python code for GTSRB文件夹下),在这里可以用到。
这里按csv描述的图像信息进行ROI部分的提取,并使用pickle保存为.p文件方便后续模型训练使用。代码示例如下:
import numpy as np
import pickle
import os
import cv2
import csv
//处理训练数据
def process_train_data(path):
file = os.listdir(path)
classes = len(file)
train_data = []
train_labels = []
for i in range(0,classes):
dir_name = file[i]
if dir_name=='.DS_Store':
continue
full_dir_path = path + dir_name
csv_file_path = full_dir_path + '/' + 'GT-{0}.csv'.format(dir_name)
with open(csv_file_path) as f:
csv_reader =