本次暑期科研见习,我有机会初步了解了人工智能的深度学习和模型压缩的基本内容,并在移动设备(树莓派3B)上进行了一些简单的深度学习模型训练。在见习结束之际,总结一下这次学习的内容,也期待之后能够继续在相关领域进行更为深入的研究。
一、深度学习的模型剪枝初探
参考:https://jacobgil.github.io/deeplearning/pruning-deep-learning
一般来讲,随着深度学习的神经网络层数越来越多、网络越来越宽,深度学习模型得到的结果会越来越精细。但与此同时,模型的参数量和计算量也会呈现激增的态势。这不仅对硬件的性能形成了挑战,同时一些冗余的模型参数也会影响计算效率。因此,就需要对一些大模型进行压缩。模型压缩大体上分为知识蒸馏、模型剪枝和量化三类方法,这里重点介绍模型剪枝。
模型剪枝的原理就是希望通过剪除对输出结果贡献不大的参数,减小模型的规模、提升运行速度,同时可以保持模型性能基本不变。
基本步骤:首先,根据对结果贡献度(weight)的大小对神经元进行排序;然后,舍去那些贡献度低的神经元,使模型的规模更精简,模型运行速度更快。当然,这里面有几点问题需要进行说明:
1.剪枝后的再训练
进行剪枝的目的之一是希望剪枝带来的模型损失(cost)越小越好.因此,剪枝后的模型需要再进行反复训练直到呈现令人满意的性能。所以,模型的剪枝实际上是一个迭代的过程,这通常称为“迭代式剪枝”;迭代的过程就是剪枝和模型训练两者的交替重复。
2.几种剪枝技术的简单介绍
不同的剪枝技术不仅包括对神经网络卷积层的处理,也包括如何选取模型参数的权重函数。
论文1:《Pruning filters for effecient convnets》
地址:https://arxiv.org/abs/1608.08710
本文提出对卷积层进行完全的剪枝。 作者提出了基于量级的裁剪方式,用weight值的大小来评判其重要性,对于一个filter,其中所有weight的绝对值求和,来作为该filter的评价指标,将一层中值低的filter裁掉,可以有效的降低模型的复杂度,并且不会给模型的性能带来很大的损失。
对卷积窗口剪枝的迭代过程中,每一轮迭代会将全部的卷积窗口进行排序(排序指标为卷积核中L1正则化的权重参数),舍弃排序后指标最低的m个卷积窗口以达到剪枝的目的,然后用剪枝后的卷积窗口进行模型训练,再不断地重复这个过程。
论文二:《Structured Pruning of Deep Convolutional Neural Networks》 地址:https://arxiv.org/abs/1512.08571
这篇论文与上一篇类似,不过在**排序上用了更加复杂的方法。论文采用了N个卷积单元过滤器 (Particle Filters)来对相应的N个卷积层进行剪枝操作。**每一个卷积单元会根据其影响模型在验证数据集上的准确率程度而被分配一个分值,分值低的卷积单元会被过滤掉以达到剪枝的目的。不过这种剪枝非常耗时。
论文三:《Pruning Convolutional Neural Networks for Resource Efficient Inference》地址:https://arxiv.org/abs/1611.06440
本文将剪枝问题当作是一个组合优化问题:从众多的权重参数中选择一个最优组合B,使得被剪枝的模型的代价函数损失最小。相应公式如下:
值得注意的是,论文用的是代价函数损失的绝对值,而不是单纯的差值。使用代价函数损失的绝对值作为优化目标,可以保证被剪枝的模型在性能上不会损失过多。
二、树莓派上简单深度学习模型的训练
当然,模型剪枝的一个设想就是能够把模型放在比较小的设备上能够运行。小型移动设备的配置和计算性能相对来讲都略逊一筹,不过作为一个训练小型深度学习模型的载体还是足够的。
树莓派系统和模块的配置
本次使用的是树莓派(Raspberry pi)3B。树莓派是一款基于ARM架构的微型电脑主板,以SD/MicroSD卡为内存硬盘,卡片主板周围有1/2/4个USB接口和一个10/100 以太网接口(A型没有网口),可连接键盘、鼠标和网线,同时拥有视频模拟信号的电视输出接口和HDMI高清视频输出接口,以上部件全部整合在一张仅比信用卡稍大的主板上,具备所有PC的基本功能只需接通显示屏、鼠标和键盘,就能执行一些简单的功能。树莓派自带的Raspbian系统基于Linux,系统默认的python版本是python2.7&3.7 。
同时,训练深度学习模型需要安装pytorch模块。PyTorch是美国互联网巨头Facebook在深度学习框架Torch的基础上使用Python重写的一个全新的深度学习框架,它更像NumPy的替代产物,不仅继承了NumPy的众多优点,还支持GPU计算,在计算效率上要比NumPy有更明显的优势;不仅如此,PyTorch还有许多高级功能,比如拥有丰富的API,可以快速完成深度神经网络模型的搭建和训练。
关于在树莓派上安装pytorch以及相关模块请参考前一篇文章:
https://blog.csdn.net/qq_44635669/article/details/96972336
当然你也可以尝试在anaconda上尝试安装pytorch模块进行模型训练。
opencv目标检测预训练(predict)模型演示
这个模型的基本原理就是把图片中的一些元素框出,和元素库中的标签元素进行比对识别。
import cv2
import time
# Pretrained classes in the model
classNames = {0: 'background',
1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane', 6: 'bus',
7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light', 11: 'fire hydrant',
13: 'stop sign', 14: 'parking meter', 15: 'bench', 16: 'bird', 17: 'cat',
18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow', 22: 'elephant', 23: 'bear',
24: 'zebra', 25: 'giraffe', 27: 'backpack', 28: 'umbrella', 31: 'handbag',
32: 'tie', 33: 'suitcase', 34: 'frisbee', 35: 'skis', 36: 'snowboard',
37: 'sports ball', 38: 'kite', 39: 'baseball bat', 40: 'baseball glove',
41: 'skateboard', 42: 'surfboard', 43: 'tennis racket', 44: 'bottle',
46: 'wine glass', 47: 'cup', 48: 'fork', 49: 'knife', 50: 'spoon',
51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange',
56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut',
61: 'cake', 62: 'chair', 63: 'couch', 64: 'potted plant', 65: 'bed',
67: 'dining table', 70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse',
75: 'remote', 76: 'keyboard', 77: 'cell phone', 78: 'microwave', 79: 'oven',
80: 'toaster', 81: 'sink', 82: 'refrigerator', 84: 'book', 85: 'clock',
86: 'vase', 87: 'scissors', 88: 'teddy bear', 89: 'hair drier', 90: 'toothbrush'}
def id_class_name(class_id, classes):
for key, value in classes.items():
if class_id == key:
return value
# Loading model
time_start=time.time()
model = cv2.dnn.readNetFromTensorflow('models/frozen_inference_graph.pb',
'models/ssd_mobilenet_v2_coco_2018_03_29.pbtxt')
image = cv2.imread("image.jpeg")
model.setInput(cv2.dnn.blobFromImage(image, size=(300, 300), swapRB=True))
output = model.forward()
# print(output[0,0,:,:].shape)
for detection in output[0, 0, :, :]:
confidence = detection[2]
if confidence > .5:
class_id = detection[1]
class_name=id_class_name(class_id,classNames)
print(str(str(class_id) + " " + str(detection[2]) + " " + class_name))
box_x = detection[3] * image_width
box_y = detection[4] * image_height
box_width = detection[5] * image_width
box_height = detection[6] * image_height
cv2.rectangle(image, (int(box_x), int(box_y)), (int(box_width), int(box_height)), (23, 230, 210), thickness=1)
cv2.putText(image,class_name ,(int(box_x), int(box_y+.05*image_height)),cv2.FONT_HERSHEY_SIMPLEX,(.005*image_width),(0, 0, 255))
time_end=time.time()
time_run=time_end-time_start
time_predict=time_end-time_begin
cv2.imshow('image', image)
# cv2.imwrite("image_box_text.jpg",image)
print('predict time:',time_predict)
print('run time:',time_run)
cv2.waitKey(0)
cv2.destroyAllWindows()
最后在jupyter notebook上的运行结果:
1.0 0.6592765 person
18.0 0.8725562 dog
predict time: 0.10614347457885742
run time: 0.526254415512085
分别显示了识别的元素key和value、accuracy rate以及代码运行时间和模型预测时间。
当然,代码运行完毕后会自动弹出一个窗口显示识别的结果。
关于此模型的详细内容请参考:
https://heartbeat.fritz.ai/real-time-object-detection-on-raspberry-pi-using-opencv-dnn-98827255fa60