Part 1.1 开始
从Kaggle Notebook《Is it a bird? Creating a model from your own data 》作为开始。简明扼要的用fine-tune ResNet预测鸟的例子作为切入点,同时展示了FastAI的几个主要库:
- fastcore: 一个用于简化 Python 编程和提高代码可读性的库
- fastdownload:简单的文件下载库,同时支持检查功能
- fastai.vision: 很强大的库,使得处理计算机视觉任务变得更加简单和快速。
1. 下面代码首先定义一个方法从DDG网站获取数据
from duckduckgo_search import ddg_images
from fastcore.all import *
from fastdownload import download_url
def search_images(term, max_images=30):
print(f"Searching for '{term}'")
return L(ddg_images(term, max_results=max_images)).itemgot('image')
其中duckduckgo_search函数被用于在 DuckDuckGo 搜索引擎中搜索图片
2. 展示例子图片
from fastai.vision.all import *
im = Image.open(dest)
im.to_thumb(256,256)
3. 获取数据集
searches = 'forest','bird'
path = Path('bird_or_not')
from time import sleep
for o in searches:
dest = (path/o)
dest.mkdir(exist_ok=True, parents=True)
download_images(dest, urls=search_images(f'{o} photo'))
sleep(10) # Pause between searches to avoid over-loading server
download_images(dest, urls=search_images(f'{o} sun photo'))
sleep(10)
download_images(dest, urls=search_images(f'{o} shade photo'))
sleep(10)
# 调用 resize_images 函数,将 path/o 路径上的所有图片的大小调整为 400 像素,并将处理后的图片保存在 path/o 路径上。
resize_images(path/o, max_size=400, dest=path/o)
4. 检查问题数据
# 检查这些图片文件是否可以正常打开。所有无法打开的图片文件会被返回,并赋值给 failed
failed = verify_images(get_image_files(path))
# 删除所有无效的图片文件
failed.map(Path.unlink)
len(failed)
5. 生成数据集
dls = DataBlock(
#这行代码定义了两个 blocks,一个是 ImageBlock,用于处理图像数据,另一个是 CategoryBlock,用于处理标签(类别)数据。
blocks=(ImageBlock, CategoryBlock),
#这行代码指定了用于获取图片文件的函数
get_items=get_image_files,
# 随机分割数据,其中 20% 的数据被用作验证集
splitter=RandomSplitter(valid_pct=0.2, seed=42),
# 指定了用于获取每个图片的类别(即标签)的函数: 目录上一层
get_y=parent_label,
# 将每张图片调整为 192x192 像素大小,使用 'squish' 方法进行缩放,这意味着图片会被压缩或拉伸以适应指定的尺寸,而不是被剪裁
item_tfms=[Resize(192, method='squish')]
).dataloaders(path, bs=32)
# 显示一个批次中的几个图像
dls.show_batch(max_n=6)
6. 模型训练
learn = vision_learner(dls, resnet18, metrics=error_rate)
learn.fine_tune(3)
7. 模型预测
dest = 'bird.jpg'
urls = search_images('bird photos', max_images=1)
download_url(urls[0], dest, show_progress=False)
is_bird,_,probs = learn.predict(PILImage.create('bird.jpg'))
print(f"This is a: {is_bird}.")
print(f"Probability it's a bird: {probs[0]:.4f}")