运行环境
win10+python3.5+gpu版本的caffe
步骤
- 下载数据集
- 将数据集转为lmdb
- 训练
- 测试训练的出来的模型
下载数据集
mnist官网下载下面4个文件
-
train-images-idx3-ubyte.gz: 训练图片集 (9912422 bytes)
-
train-labels-idx1-ubyte.gz: 训练图片集的打标值 (28881 bytes)
-
t10k-images-idx3-ubyte.gz: 测试图片集 (1648877 bytes)
-
t10k-labels-idx1-ubyte.gz: 测试图片集的打标值(4542 bytes)
分别解压这四个文件得到:
t10k-images.idx3-ubyte
t10k-labels.idx1-ubyte
train-images.idx3-ubyte
train-labels.idx1-ubyte
它们的结构在mnist网站上有说明
训练图片集的打标值文件 (train-labels-idx1-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000801(2049) magic number (MSB first)
0004 32 bit integer 10000 标签值总数
0008 unsigned byte ?? 标签值
0009 unsigned byte ?? 标签值
........
xxxx unsigned byte ?? 标签值
训练图片集文件 (train-images-idx3-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 10000 图片总数
0008 32 bit integer 28 单张图片的长度像素值数量
0012 32 bit integer 28 单张图片的高度像素值数量
0016 unsigned byte ?? 单像素值
0017 unsigned byte ?? 单像素值
........
xxxx unsigned byte ?? 单像素值
测试图片集的打标值文件 (t10k-labels-idx1-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000801(2049) magic number (MSB first)
0004 32 bit integer 10000 标签值总数
0008 unsigned byte ?? 标签值
0009 unsigned byte ?? 标签值
........
xxxx unsigned byte ?? 标签值
标签值的范围是0-9
测试图片集文件 (t10k-images-idx3-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 10000 图片总数
0008 32 bit integer 28 单张图片的长度像素值数量
0012 32 bit integer 28 单张图片的高度像素值数量
0016 unsigned byte ?? 单像素值
0017 unsigned byte ?? 单像素值
........
xxxx unsigned byte ?? 单像素值
数据集转换为lmdb
下载的数据集有两对,一对是训练数据图片集和它对应的标签值, 另一对是测试图片集和它对应的标签值,这两对文件的结构是一样的,因此转换为lmdb文件时,可以使用同样的方法
def orgin_to_lmdb(image_file, label_file, lmdb_save_path, force_update=False):
mean_file = '{}.binaryproto'.format(lmdb_save_path)
if os.path.exists(mean_file) and os.path.exists(lmdb_save_path) and force_update == False:
return
try:
shutil.rmtree(lmdb_save_path)
except:
pass
try:
shutil.rmtree(mean_file)
except:
pass
with open(image_file, 'rb') as image_f:
with open(label_file, 'rb') as label_f:
# 读取标签文件头的4个整型
size = struct.calcsize('>2I')
magic, num_items = struct.unpack_from('>2I', label_f.read(size))
print(magic, num_items)
# 读取图片文件头的4个整型
size = struct.calcsize('>4I')
magic, num_images, num_rows, num_columns = struct.unpack_from('>4I', image_f.read(size))
print(magic, num_images, num_rows, num_columns)
map_size = num_images*num_rows*num_columns * 1.5
# 遍历所有图片,将文件列表写入到lmdb中
with lmdb.open(lmdb_save_path,map_size=map_size) as in_db:
with in_db.begin(write=True) as in_txn:
im_size = num_rows * num_columns
label_size = struct.calcsize('>B')
im_idx = 0
while im_idx < num_images:
img_item = struct.unpack_from('>B', label_f.read(label_size))[0]
img_buf = image_f.read(im_size)
datum = caffe_pb2.Datum(
channels=1, # 数据集里面的图片是灰度图,因此通道数设置为1
width=num_columns,
height=num_rows,
label=int(img_item),
data=img_buf
)
in_txn.put('{:0>8d}'.format(im_idx).encode('utf8'), datum.SerializeToString())
im_idx += 1
# 生成mean文件
cmd = '{0} {1} {2}'.format(compute_image_mean, lmdb_save_path, mean_file)
print(cmd)
os.system(cmd)
以下代码可以打开lmdb查看第一张图片
# 查看lmdb的第一张图片
def show_lmdb_first_image(lmdb_save_path):
with lmdb.open(lmdb_save_path, readonly=True) as lmdb_env:
lmdb_txn = lmdb_env.begin()
lmdb_cursor = lmdb_txn.cursor()
datum = caffe_pb2.Datum()
lmdb_cursor.first()
key, value = lmdb_cursor.item()
datum.ParseFromString(value)
label = datum.label
data = caffe.io.datum_to_array(datum)
print(label, datum.channels, data.shape)
image = data.transpose(1, 2, 0)
cv2.imshow('cv2.png', image)
cv2.waitKey(0)
cv2.destroyAllWindows()
使用数据集进行训练
使用caffe代码目录下的examples\mnist\lenet_solver.prototxt
和examples\mnist\lenet_train_test.prototxt
, 需要修改lenet_solver.prototxt
中的网络文件地址为新的lenet_train_test.prototxt
需要修改lenet_train_test.prototxt
的数据层为刚才生成的lmdb地址
solver = caffe.SGDSolver('lenet_solver.prototxt')
solver.solve()
完成之后会产生两个模型文件lenet_iter_5000.caffemodel
和lenet_iter_10000.caffemodel
测试训练的出来的模型
需要先生成一个网络配置文件, 一般是改动训练时用的网络配置文件,这里直接使用examples\mnist\lenet.prototxt
net = caffe.Net(
'lenet.prototxt', # 网络配置文件
caffe.TEST,
weights='lenet_iter_10000.caffemodel' # 训练产生的模型
)
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
transformer.set_transpose('data', (2,0,1))
transformer.set_raw_scale('data', 255)
# transformer.set_channel_swap('data', (2, 1, 0)) # minist用的是灰度图 channel只有1,因此无需转换
# 因为minist的channel是1, 所以需要转为灰度图color=False
im = caffe.io.load_image('3.jpg', color=False) # 打开测试图片
net.blobs['data'].data[0] = transformer.preprocess('data', im)
res = net.forward()
print(res['prob'].argmax())
测试图片是用windows测试工具写的几个数字,需要黑底白字,并且图片大小要改为28*28
有的识别会出错。。。