两周多的努力总算写出了RCNN的代码,这段代码非常有意思,并且还顺带复习了几个Tensorflow应用方面的知识点,故特此总结下,带大家分享下经验。理论方面,RCNN的理论教程颇多,这里我不在做详尽说明,有兴趣的朋友可以看看这个博客以了解大概。
系统概况
RCNN的逻辑基于Alexnet模型。为增加模型的物体辨识率,在图片未经CNN处理前,先由传统算法(文中所用算法为Selective Search算法)取得大概2000左右的疑似物品框。之后,这些疑似框被导入CNN系统中以取得输出层前一层的特征后,由训练好的svm来区分物体。这之中,比较有意思的部分包括了对经过ImageNet训练后的Alexnet的fine tune,对fine tune后框架里输出层前的最后一层特征点的提取以及训练svm分类器。下面,让我们来看看如何实现这个模型吧!
代码解析
为方便编写,这里应用了tflearn库作为tensorflow的一个wrapper来编写Alexnet,关于tflearn,具体资料请点击这里查看其官网。
那么下面,让我们先来看看系统流程:
第一步,训练Alexnet,这里我们运用的是github上tensorflow-alexnet项目。该项目将Alexnet运用在学习flower17数据库上,说白了也就是区分不同种类的花的项目。github提供的代码所有功能作者都有认真的写出,不过在main的写作以及对模型是否支持在断点处继续训练等问题上作者并没写明,这里贴上我的代码:
def train(network, X, Y):
# Training
model = tflearn.DNN(network, checkpoint_path='model_alexnet',
max_checkpoints=1, tensorboard_verbose=2, tensorboard_dir='output')
# 这里增加了读取存档的模式。如果已经有保存了的模型,我们当然就读取它然后继续
# 训练了啊!
if os.path.isfile('model_save.model'):
model.load('model_save.model')
model.fit(X, Y, n_epoch=100, validation_set=0.1, shuffle=True,
show_metric=True, batch_size=64, snapshot_step=200,
snapshot_epoch=False, run_id='alexnet_oxflowers17') # epoch = 1000
# Save the model
# 这里是保存已经运算好了的模型
model.save('model_save.model')
同时,我们希望可以检测模型是否运作正常。以下是检测Alexnet用代码
# 预处理图片函数:
# ------------------------------------------------------------------------------------------------
# 首先,读取图片,形成一个Image文件
def load_image(img_path):
img = Image.open(img_path)
return img
# 将Image文件给修改成224 * 224的图片大小(当然,RGB三个频道我们保持不变)
def resize_image(in_image, new_width, new_height, out_image=None,
resize_mode=Image.ANTIALIAS):
img = in_image.resize((new_width, new_height), resize_mode)
if out_image:
img.save(out_image)
return img
# 将Image加载后转换成float32格式的tensor
def pil_to_nparray(pil_image):
pil_image.load()
return np.asarray(pil_image, dtype="float32")
# 网络框架函数:
# ------------------------------------------------------------------------------------------------
def create_alexnet(num_classes):
# Building 'AlexNet'
network = input_data(shape=[None, 224, 224, 3])
network = conv_2d(network, 96, 11, strides=4, activation='relu')
network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)
network = conv_2d(network, 256, 5, activation='relu')
network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)
network = conv_2d(network, 384, 3, activation='relu')
network = conv_2d(network, 384, 3, activation='relu')
network = conv_2d(network, 256, 3, activation='relu')
network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)
<