Preparation
After opening a Colab notebook, run the following code to mount Google Drive:
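from google.colab import drive
# Unmount first in case Drive is already mounted, then mount it again
drive.flush_and_unmount()
drive.mount('/content/gdrive', force_remount=True)
# Files in "My Drive" are now available under this path
root_dir = "/content/gdrive/My Drive/"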
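Colab will ask you to open a URL and enter a security code (which can be found via Google Play on your phone). After that you can continue.
Creating folders
Create two folders, one named cat and one named dog: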
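from fastai.vision import *   # fastai v1 API; the star import also provides Path

# Root folder of the dataset on Google Drive
path = Path(root_dir + 'DeepLearning/Datasets2/')
# One sub-folder per class; exist_ok=True makes re-running the cell safe
dest1 = path/'cat'
dest1.mkdir(parents=True, exist_ok=True)
dest2 = path/'dog'
dest2.mkdir(parents=True, exist_ok=True)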
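If you later want more than two classes, the same folder setup can be written as a loop over a list of class names. A minimal sketch using only pathlib; the function name make_class_dirs and the temporary directory in the usage example are illustrative assumptions, not part of fastai:

```python
import tempfile
from pathlib import Path
from typing import Dict, Iterable


def make_class_dirs(root: Path, classes: Iterable[str]) -> Dict[str, Path]:
    """Create one sub-folder per class under root and return them by class name."""
    dirs = {}
    for name in classes:
        d = root / name
        # parents=True creates missing parent folders; exist_ok=True makes re-runs safe
        d.mkdir(parents=True, exist_ok=True)
        dirs[name] = d
    return dirs


if __name__ == "__main__":
    # Hypothetical usage on a throwaway folder instead of Google Drive
    with tempfile.TemporaryDirectory() as tmp:
        dirs = make_class_dirs(Path(tmp) / "Datasets2", ["cat", "dog"])
        print(sorted(dirs))
```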
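Building the dataset
Search for cat on Google Images and scroll down for a while so that more results load. Then press F12 and switch to the Console tab (or press Ctrl+Shift+J) to open the developer console, and paste in the following five snippets: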
// 1. Simulate a right click so that Google fills in the full-size
//    image URL on the <a> element that wraps the thumbnail
function simulateRightClick( element ) {
    var rect = element.getBoundingClientRect();
    var event1 = new MouseEvent( 'mousedown', {
        bubbles: true,
        cancelable: false,
        view: window,
        button: 2,
        buttons: 2,
        clientX: rect.x,
        clientY: rect.y
    } );
    element.dispatchEvent( event1 );
    var event2 = new MouseEvent( 'mouseup', {
        bubbles: true,
        cancelable: false,
        view: window,
        button: 2,
        buttons: 0,
        clientX: rect.x,
        clientY: rect.y
    } );
    element.dispatchEvent( event2 );
    var event3 = new MouseEvent( 'contextmenu', {
        bubbles: true,
        cancelable: false,
        view: window,
        button: 2,
        buttons: 0,
        clientX: rect.x,
        clientY: rect.y
    } );
    element.dispatchEvent( event3 );
}

// 2. Return the value of `key` in a query string, or false if it is absent
function getURLParam( queryString, key ) {
    var vars = queryString.replace( /^\?/, '' ).split( '&' );
    for ( let i = 0; i < vars.length; i++ ) {
        let pair = vars[ i ].split( '=' );
        if ( pair[0] === key ) {
            return pair[1];
        }
    }
    return false;
}

// 3. Save `contents` as a text file called urls.txt
function createDownload( contents ) {
    var hiddenElement = document.createElement( 'a' );
    // encodeURIComponent (unlike encodeURI) also escapes '#', which would
    // otherwise cut the data URL short
    hiddenElement.href = 'data:text/plain;charset=utf-8,' + encodeURIComponent( contents );
    hiddenElement.target = '_blank';
    hiddenElement.download = 'urls.txt';
    document.body.appendChild( hiddenElement );
    hiddenElement.click();
    document.body.removeChild( hiddenElement );
}

// 4. Collect the full-size URLs of all image results loaded so far
function grabUrls() {
    var urls = [];
    return new Promise( function( resolve, reject ) {
        // '.isv-r' is the result container class Google Images used when this
        // script was written; update the selector if Google changes its markup
        var links = document.querySelectorAll( '.isv-r a:first-of-type' ),
            count = links.length,
            index = 0;
        if ( count === 0 ) {
            reject( new Error( 'No image results found; check the selector' ) );
            return;
        }
        // Count one processed link; resolve once every link has been handled
        function done() {
            index++;
            if ( index === count ) {
                resolve( urls );
            }
        }
        Array.prototype.forEach.call( links, function( element ) {
            var img = element.querySelector( ':scope img' );
            if ( !img ) {
                done();
                return;
            }
            // using the right click menu Google will generate the
            // full-size URL; won't work in Internet Explorer
            // (http://pyimg.co/byukr)
            simulateRightClick( img );
            // wait for the URL to appear on the <a> element
            var interval = setInterval( function() {
                if ( element.href.trim() !== '' ) {
                    clearInterval( interval );
                    // extract the full-size version of the image
                    let googleUrl = element.href.replace( /.*(\?)/, '$1' ),
                        fullImageUrl = decodeURIComponent(
                            getURLParam( googleUrl, 'imgurl' ) );
                    // getURLParam returns false when 'imgurl' is missing,
                    // which decodeURIComponent turns into the string "false";
                    // such links are skipped but still counted
                    if ( fullImageUrl !== 'false' ) {
                        urls.push( fullImageUrl );
                    }
                    done();
                }
            }, 10 );
        } );
    } );
}

// 5. Run everything and download the URL list
grabUrls().then( function( urls ) {
    createDownload( urls.join( '\n' ) );
} ).catch( function( err ) {
    console.error( err );
} );
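The URL extraction inside grabUrls can be checked offline. Below is a minimal Python equivalent of getURLParam followed by decodeURIComponent, assuming the result link has the /imgres?imgurl=...&imgrefurl=... shape the script expects; the sample link and the function name extract_imgurl are made up for illustration:

```python
from typing import Optional
from urllib.parse import parse_qs, urlsplit


def extract_imgurl(href: str) -> Optional[str]:
    """Return the decoded 'imgurl' query parameter of a result link, or None if absent."""
    query = urlsplit(href).query
    # parse_qs percent-decodes values like decodeURIComponent does
    # (unlike decodeURIComponent, it also turns '+' into a space)
    values = parse_qs(query).get("imgurl")
    return values[0] if values else None


if __name__ == "__main__":
    sample = ("https://www.google.com/imgres?"
              "imgurl=https%3A%2F%2Fexample.com%2Fcat.jpg"
              "&imgrefurl=https%3A%2F%2Fexample.com%2F")
    print(extract_imgurl(sample))
```

Returning None instead of the string "false" avoids the special case the JavaScript version has to handle.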
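Repeat the same steps for dog.
Rename the two downloaded urls.txt files to cat.txt and dog.txt respectively, for use in the next step.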
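Before uploading, it can help to drop blank lines, non-HTTP entries (for example inline data: URLs) and duplicates from each file, so that the download step does not waste attempts on them. A minimal sketch; clean_urls is a hypothetical helper, not a fastai function:

```python
from typing import Iterable, List


def clean_urls(lines: Iterable[str]) -> List[str]:
    """Keep unique http(s) URLs in their original order."""
    seen = set()
    result = []
    for line in lines:
        url = line.strip()
        if not url.startswith(("http://", "https://")):
            continue  # skip blank lines and anything that is not a web URL
        if url in seen:
            continue  # skip duplicates
        seen.add(url)
        result.append(url)
    return result


if __name__ == "__main__":
    raw = ["https://a.com/1.jpg", "", "https://a.com/1.jpg",
           "data:image/png;base64,xx", "http://b.com/2.png"]
    print(clean_urls(raw))
```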
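Downloading the images
Upload cat.txt and dog.txt to the path directory, then run the commands below. They call download_images from fastai.vision.data, whose arguments give the location of the URL file, the destination folder for the images, the maximum number of images to download (max_pics), and the maximum number of parallel workers (max_workers, left at its default here):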
download_images(path/'cat.txt', dest1, max_pics=120)
download_images(path/'dog.txt', dest2, max_pics=120)
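Conceptually, the download step reads the URL list, keeps at most max_pics entries, and fetches them in parallel, skipping links that fail. The sketch below illustrates that idea with a thread pool from concurrent.futures; it is not fastai's implementation, and download_all, the numbered .jpg file names, and the injectable fetch parameter are assumptions chosen so the logic can be tested without network access:

```python
import concurrent.futures
import tempfile
import urllib.request
from pathlib import Path
from typing import Callable, List, Optional, Tuple


def http_fetch(url: str, timeout: float = 4.0) -> bytes:
    """Download url and return the raw bytes."""
    with urllib.request.urlopen(url, timeout=timeout) as resp:
        return resp.read()


def download_all(urls: List[str], dest: Path, max_pics: int = 1000,
                 max_workers: int = 8,
                 fetch: Callable[[str], bytes] = http_fetch) -> List[Path]:
    """Download up to max_pics URLs into dest in parallel; return the files written."""
    dest.mkdir(parents=True, exist_ok=True)
    urls = urls[:max_pics]

    def save(item: Tuple[int, str]) -> Optional[Path]:
        i, url = item
        try:
            data = fetch(url)
        except Exception:
            return None  # a broken link should not stop the whole batch
        # Assumption: every file gets a .jpg name; the real format may differ
        path = dest / f"{i:08d}.jpg"
        path.write_bytes(data)
        return path

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        results = list(pool.map(save, enumerate(urls)))
    return [p for p in results if p is not None]


if __name__ == "__main__":
    def fake_fetch(url: str) -> bytes:
        # Simulates one failing link without touching the network
        if "bad" in url:
            raise OSError("simulated network error")
        return url.encode()

    with tempfile.TemporaryDirectory() as tmp:
        files = download_all(["http://x/1", "http://bad/2", "http://x/3"],
                             Path(tmp), max_pics=2, fetch=fake_fetch)
        print([f.name for f in files])
```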
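Cleaning the data
The downloaded images come in different formats (JPEG, PNG, GIF, etc.) and sizes, so they need to be cleaned. fastai provides a verify_images utility for basic cleanup: it checks whether each image is corrupted and whether it has the expected number of channels, and resizes images that are larger than max_size. With delete=True, files that cannot be opened are deleted: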
classes = ['cat','dog']
for c in classes:
    print(c)
    # Delete unreadable files and shrink images larger than 500 pixels
    verify_images(path/c, delete=True, max_size=500)
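Two of those checks can be illustrated without an image library: recognising from its first bytes whether a downloaded file really is an image (a saved HTML error page is a common failure), and computing a new size whose longer side fits within max_size while keeping the aspect ratio. This is a simplified sketch, not how verify_images is implemented (fastai opens and resizes images with PIL); sniff_format and target_size are hypothetical names:

```python
from typing import Optional, Tuple

# Leading "magic" bytes of a few common image formats
SIGNATURES = {
    b"\xff\xd8\xff": "jpeg",
    b"\x89PNG\r\n\x1a\n": "png",
    b"GIF87a": "gif",
    b"GIF89a": "gif",
}


def sniff_format(header: bytes) -> Optional[str]:
    """Guess the image format from the first bytes of a file; None if unknown."""
    for magic, fmt in SIGNATURES.items():
        if header.startswith(magic):
            return fmt
    return None


def target_size(width: int, height: int, max_size: int) -> Tuple[int, int]:
    """Scale (width, height) so the longer side is at most max_size, keeping the aspect ratio."""
    longest = max(width, height)
    if longest <= max_size:
        return width, height  # already small enough: leave unchanged
    scale = max_size / longest
    return max(1, round(width * scale)), max(1, round(height * scale))


if __name__ == "__main__":
    print(sniff_format(b"<html>not an image"))
    print(sniff_format(b"\x89PNG\r\n\x1a\n...."))
    print(target_size(1000, 600, 500))
```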
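Viewing some images
fastai organizes the data with the ImageDataBunch class. Here the images are read from the class sub-folders of path (train="."), 20% of them are held out for validation (valid_pct=0.2), the default data augmentation from get_transforms() is applied, every image is resized to 224×224, and the pixel values are normalized with ImageNet statistics: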
np.random.seed(42)
data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2,
ds_tfms=get_transforms(), size=224, num_workers=4).normalize(imagenet_stats)
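Two details of that call deserve a closer look: np.random.seed(42) fixes which 20% of the images land in the validation set, so re-running the notebook gives the same split, and normalize(imagenet_stats) shifts and scales each colour channel with ImageNet statistics so the inputs resemble what the pretrained network was trained on. A minimal pure-Python sketch of both ideas; it uses random.Random instead of NumPy, is not fastai's code, and the function names are hypothetical:

```python
import random
from typing import List, Sequence, Tuple

# Per-channel RGB mean and standard deviation of ImageNet,
# the same numbers fastai exposes as imagenet_stats
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def split_train_valid(items: Sequence[str], valid_pct: float = 0.2,
                      seed: int = 42) -> Tuple[List[str], List[str]]:
    """Randomly hold out valid_pct of items for validation; the same seed gives the same split."""
    rng = random.Random(seed)
    shuffled = list(items)
    rng.shuffle(shuffled)
    n_valid = int(len(shuffled) * valid_pct)
    return shuffled[n_valid:], shuffled[:n_valid]


def normalize_pixel(rgb: Sequence[float]) -> Tuple[float, ...]:
    """Normalize one RGB pixel (values in [0, 1]) channel by channel."""
    return tuple((v - m) / s for v, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))


if __name__ == "__main__":
    files = [f"img_{i}.jpg" for i in range(10)]
    train, valid = split_train_valid(files)
    print(len(train), len(valid))
    print(normalize_pixel((0.485, 0.456, 0.406)))
```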
To view a random batch of images:
data.show_batch(rows=3, figsize=(7,6))
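Training the model
Use transfer learning with the pretrained resnet34 model. cnn_learner keeps the pretrained layers frozen at first, so this step only trains the newly added head: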
learn = cnn_learner(data, models.resnet34, metrics=error_rate)
learn.fit_one_cycle(4)
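fit_one_cycle(4) trains for 4 epochs with the one-cycle policy: the learning rate first rises from a small value to max_lr, then anneals to a much smaller value by the end of training. The sketch below computes such a schedule with cosine annealing. The defaults (max_lr=3e-3, div_factor=25, pct_start=0.3, final rate max_lr/(div_factor*1e4)) are assumed to mirror fastai v1's defaults; check the documentation for your version. fastai also cycles momentum, which is left out here, and one_cycle_lrs is a hypothetical name:

```python
import math
from typing import List


def cos_anneal(start: float, end: float, pct: float) -> float:
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)


def one_cycle_lrs(total_steps: int, max_lr: float = 3e-3, div_factor: float = 25.0,
                  pct_start: float = 0.3, final_div: float = 25.0 * 1e4) -> List[float]:
    """Learning rate for each training step under a one-cycle schedule."""
    warmup = int(total_steps * pct_start)
    lrs = []
    for step in range(total_steps):
        if step < warmup:
            # Phase 1: warm up from max_lr / div_factor to max_lr
            lrs.append(cos_anneal(max_lr / div_factor, max_lr, step / warmup))
        else:
            # Phase 2: anneal from max_lr down to max_lr / final_div
            span = max(1, total_steps - 1 - warmup)
            lrs.append(cos_anneal(max_lr, max_lr / final_div, (step - warmup) / span))
    return lrs


if __name__ == "__main__":
    lrs = one_cycle_lrs(100)
    print(f"start={lrs[0]:.2e} peak={max(lrs):.2e} end={lrs[-1]:.2e}")
```

Starting low avoids large, destabilising updates to the new head early on, and ending very low lets the weights settle.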
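Save this stage, then unfreeze the model so that all layers are trained, and run the learning-rate finder to pick a new learning rate:
learn.save('stage-1')
learn.unfreeze()
learn.lr_find()
learn.recorder.plot()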
Continue training with the adjusted learning rates:
learn.fit_one_cycle(10,max_lr=slice(1e-4,3e-4))
learn.save('stage-2')
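After unfreezing, max_lr=slice(1e-4,3e-4) gives different learning rates to different parts of the network: the earliest layers, which hold generic pretrained features, get the smallest rate and the new head gets the largest. The rates in between are assumed here to be spaced geometrically across the layer groups (three groups are assumed for a resnet34 cnn_learner). A small sketch of that spacing; spread_lrs is a hypothetical name and the single-group case is an illustrative choice:

```python
from typing import List


def spread_lrs(start: float, stop: float, n_groups: int) -> List[float]:
    """Geometrically spaced learning rates from start (first group) to stop (last group)."""
    if n_groups == 1:
        return [stop]  # a single group simply uses the largest rate
    ratio = (stop / start) ** (1 / (n_groups - 1))
    return [start * ratio ** i for i in range(n_groups)]


if __name__ == "__main__":
    # Assumed: 3 layer groups, as for a resnet34 cnn_learner
    for group, lr in enumerate(spread_lrs(1e-4, 3e-4, 3)):
        print(f"group {group}: lr = {lr:.2e}")
```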
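fastai provides ClassificationInterpretation for analysing the results: plot_confusion_matrix shows how often each class is mistaken for the other, and plot_top_losses shows the images on which the model's loss is highest: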
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(figsize=(12,12), dpi=60)
interp.plot_top_losses(9, figsize=(15,11))
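Both plots are built from the validation-set predictions: the confusion matrix counts how often each actual class was predicted as each class, and the top losses are the images where the model gave the lowest probability to the correct label. A pure-Python sketch of the underlying computation, assuming cross-entropy loss and made-up predictions; the function names are hypothetical and this is not fastai's implementation:

```python
import math
from collections import Counter
from typing import List, Sequence, Tuple


def confusion_matrix(actual: Sequence[str], predicted: Sequence[str],
                     classes: Sequence[str]) -> List[List[int]]:
    """Rows are actual classes, columns are predicted classes."""
    counts = Counter(zip(actual, predicted))
    return [[counts[(a, p)] for p in classes] for a in classes]


def top_losses(actual_idx: Sequence[int], probs: Sequence[Sequence[float]],
               k: int) -> List[Tuple[int, float]]:
    """Indices and cross-entropy losses of the k worst-predicted samples."""
    # Loss is -log(probability of the correct class); clamp to avoid log(0)
    losses = [-math.log(max(p[y], 1e-12)) for y, p in zip(actual_idx, probs)]
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)
    return [(i, losses[i]) for i in order[:k]]


if __name__ == "__main__":
    classes = ["cat", "dog"]
    actual = ["cat", "cat", "dog", "dog"]
    predicted = ["cat", "dog", "dog", "dog"]
    print(confusion_matrix(actual, predicted, classes))
    probs = [[0.9, 0.1], [0.3, 0.7], [0.2, 0.8], [0.05, 0.95]]
    print(top_losses([0, 0, 1, 1], probs, k=2))
```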
Reference:
https://medium.com/dain-studios/creating-a-computer-vision-api-in-60-minutes-658ff64ae4f7