上一章我们用了MobileNet已经训练好的模型进行分类,这一节,我们来讲讲什么是Feature Extractor 特征提取,以及在此基础上训练和分类。
观看本教程的视频:https://www.bilibili.com/video/BV1az4y1Z742?p=2
一、开头一段代码
还是打开ml5.js的在线编程网页:https://chn.ai/ml5.html,录入下面代码,点击运行。第一次运行的时候浏览器会请求您电脑摄像头的权限,点“允许”。
<!DOCTYPE html>
<html lang="en">
<head>
<title>Feature Extractor </title>
<script src='js/ml5.min.js'></script>
</head>
<body onload='pageLoaded()'>
<video id='videoElement' autoplay='true'></video>
<br/>
<input type='button' value='add phone' onclick='addPhone()' />
<input type='button' value='add cup' onclick='addCup()' />
<br/>
<input type='button' value='train' onclick='train()' />
<input type='button' value='classify' onclick='classify()' />
<hr/>
<div id='log' style='font-size: 80%; background-color: #efefef'></div>
<script>
var video = null;
var featureExtractor = null;
var classifier = null;
function pageLoaded() {
initVideo();
initExtractor();
}
function initExtractor() {
featureExtractor = ml5.featureExtractor('MobileNet', function() {
log('model loaded');
classifier = featureExtractor.classification(video, function() {
log('classifier inited');
});
})
}
function initVideo() {
video = document.getElementById('videoElement');
if(!navigator.mediaDevices.getUserMedia) return;
navigator.mediaDevices.getUserMedia({video: true})
.then(function(stream) {
video.srcObject = stream;
log('video ready');
})
.catch(function(err) {
log(err);
})
}
function addPhone() {
classifier.addImage('phone', function() {log('phone added');});
}
function addCup() {
classifier.addImage('cup', function() {log('cup added');});
}
function train() {
classifier.train(function(loss) {
log(loss);
})
}
function classify() {
classifier.classify(function(err, result) {
if(err) {
log(err);
return;
}
// else
log(JSON.stringify(result));
})
}
function log(content) {
var elem = document.getElementById('log');
if(!elem) return;
elem.innerHTML = content + "<br/>" + elem.innerHTML;
}
</script>
</body>
</html>
这段代码的作用就是通过电脑摄像头录制的视频作为源进行训练。比如放一个杯子在摄像头前,点击“add phone”,就添加了一个训练数据。从不同角度,不同远近多添加几个‘杯子’的样本;然后以同样的办法添加‘手机’的样本。点击‘train’进行训练。等训练停止过后,再将杯子或手机或者其它任何东西放到摄像头前,点击‘classify’进行识别,看看识别分类的效果。
二、原理
为什么不多的图片,一会儿就能训练出一个准确率还算不错的分类器呢?不是听说深度学习需要大量的数据,长时间的烧显卡才能训练处一个模型吗?要理解这个问题,我们可以拿人类的认知来理解一下:我们告诉一个从来没有见过兔子的小孩说有四条小短腿,两个长耳朵,红眼睛,毛茸茸的小动物就是是小白兔,即使这他没有亲眼见过,我保证一见到小白兔这个小孩就能立刻辨别出来。但是假设一个外星小孩来到地球,我们同样告诉它有四条小短腿,两个长耳朵,红眼睛,毛茸茸的小动物就是是小白兔,这个外星小孩看到小白兔它也不认识。因为地球小孩对腿,耳朵,眼睛,毛茸茸这些属性都有先验的知识,他在其他动物那里学到了腿长什么样,耳朵什么样,眼睛又是什么样,而外星小孩没有这些知识。
我们在代码里面生成一个叫featureExtractor的对象,这个feature就是‘特征’,这些特征是mobileNet针对一个叫imageNet的图片集进行了训练。我们知道神经网络通常是由很多层的网络构成的。研究人员发现,构成mobileNet的卷积神经网络,最开始几层是对图像的一些线段,点,边界做一些抽象;后面几层对更深一层的特征进行抽象,比如一些形状,颜色,区域块等;再后面的网络层则对更抽象的特征更活跃,比如看到眼睛,耳朵,人脸等等。刚刚我们的训练代码就好像小孩认识兔子一样,只需要训练那些抽象特征的集合对应什么物体即可,所以只需要相对少的数据集和计算就可以得到不错的训练结果。
ml5训练feature extractor,大部分的训练都已经做好,ml5feature extractor只是训练这一系列的feature是某某物体,这样训练所需要的数据,时间大大缩短,因为我们用自己的图片,只是训练mobileNet网络少量的层。就好像我们告诉一个从来没有见过兔子的小孩说有四条小短腿腿,两个长耳朵,红眼睛,毛茸茸的小动物是小白兔,即使这个小孩没有亲眼见过,他一看到小白兔应该都能辨别出来。因为小孩对腿,耳朵,眼睛,毛茸茸这些属性都有先验的知识,如果我们同样去告诉一个没见过任何动物的小孩,十有八九这个小孩没法分辨小白兔。因为腿,耳朵,眼睛,毛茸茸这些东西他没有任何概念。ml5训练我们自己的图像过程类似,我们只用训练四条小短腿腿,两个长耳朵,红眼睛,毛茸茸的动物是小白兔,四条长腿,长嘴巴,三角耳,大尾巴,灰色皮毛的是大灰狼就行了,而不用训练ml5怎么去识别腿,尾巴,眼睛,颜色和皮毛等等,这些工作在其它人的训练里面已经做了。
一个有趣的工具叫Deep Visualization Toolbox。
它可以把卷积神经网络的不同的层所关注的展示成图像,从此可以发现许多有趣的内容。比如有的神经元关注人脸,这个可以理解;有的神经元关注人物衣服上的褶皱,这个就有些意外了。而这些关注点不是人告诉它要这么做的,而是通过训练,神经元不断调整而自发形成的,是不是很有意思?
三、代码讲解
再来回顾一下前面我们的代码。代码中大部分内容是页面的UI部分,主要就是初始化摄像头并将摄像头内容投射到页面的<video />
元素中。
真正ml5的代码简化如下:
featureExtractor = ml5.featureExtractor('MobileNet', callback);
classifier = featureExtractor.classification(imageSource, callback);
classifier.addImage('label', callback);
classifier.train(...);
classifier.classify(...);
先创建一个ml5.featureExtractor
对象,指定还是用‘MobileNet’模型,这是一个常用的卷积神经网络,ml5以及tensorflow.js已经帮我们实现了这个网络的结构。在初始化这个对象的时候,ml5会从网络上下载网络的结构定义以及预训练神经的权重数据;然后创建featureExtractor.classification(imageSource, callback)
对象,两个参数第一个是图像数据来源,我们例子里面是<video/>
元素;然后调用addImage
方法添加训练数据;然后调用train
方法进行训练;最后通过classify
方法就可以进行图像分类了。
好了,这节课我们学习了用ml5.js的feature extractor训练和分类图像,还简单介绍了一下feature(就是特征,或者是属性),以及feature visualization特征可视化。特征可视化是一个比较有意思的话题,我们有机会可以专门开一个文章来介绍一下这个内容。
下面一章,我们将来学习一下ml5的Regression 回归 & 保存/加载模型
如果大家有任何意见,建议,idea,或者在编码过程中遇到任何问题,欢迎在下边留言,我看到会一一回复各位。谢谢大家!