构建图片分类器
依赖包(Libraries)
下载依赖包
pip3 install fastai bing_image_downloader
导入依赖包
from time import sleep
from fastcore.all import *
from fastai.vision.all import *
from bing_image_downloader import downloader
训练器(Dataloader)
准备数据
"target"为分类对象,可自行修改
这里的“50”为下载的训练图片的数量,可自行修改(建议至少50张)
因为使用bing下载图片,有部分图片国内网络会下载失败,因此下载完成总时长15分钟左右
folder1 = Path('test_set')
folder2 = Path('train_set')
target = 'earth', 'moon'
for sort in target:
downloader.download(f'{sort}', 1, 'test_set', verbose=False, timeout=10)
downloader.download(f'{sort}', 50, 'train_set', verbose=False, timeout=10)
resize_images(folder1/sort, max_size=400, dest=folder1/sort)
resize_images(folder2/sort, max_size=400, dest=folder2/sort)
创建训练器
blocks = ( input_block, output_block ), 这里是输入图片,输出类别
get_items 是文件种类,这里是图片文件
get_y 是类别标签,这里指的是文件夹名称作为类别标签
item_tfms 是item_transforms, 这里是统一图片大小(以便训练)
splitter 用于将下载的图片分成training和validation set,这里是用的随机划分,0.2指的是将20%划分为validation set
DLS = DataBlock(
blocks = (ImageBlock, CategoryBlock),
get_items = get_image_files,
get_y = parent_label,
item_tfms = Resize(224),
splitter = RandomSplitter(valid_pct=0.2, seed=42)
).dataloaders(folder, bs=32)
检查训练器
显示训练数据集中的6张图片,确认图片的质量并检查标签是否正确
DLS.show_batch(max_n=6)
分类器(Classifier)
训练分类器
这里的“20”为训练次数,可自行修改(由于这里训练图片较少,因此训练多次)
earth_model = vision_learner(DLS, resnet18, metrics=error_rate)
earth_model.fine_tune(20)
检验分类器
这里的earth.jpg是test_set文件夹里的(默认是Image_1.jpg),自行重命名后需要移动到图片分类器代码文件所在的目录下
label,_,probs = earth_model.predict(PILImage.create('earth.jpg'))
print(f"This is {label}.")
print(f"Probability of it's the earth: {probs[0]:.4f}")