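This article uses the pretrained TorchVision Faster R-CNN model to practice a common transfer learning technique, fine-tuning, on Kaggle: Global Wheat Detection 🌾.
The Kaggle Notebooks for this article are available at:
If you do not have a GPU, you can also train online on Kaggle. Usage guide:
So, let's get started 💪
Prepare the data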
import os
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from PIL import Image
Download the data
Download the data from the Kaggle: Global Wheat Detection Data page. It contains the following:
- train.csv - the training data
- sample_submission.csv - a sample submission file in the correct format
- train.zip - training images
- test.zip - test images
DIR_INPUT = 'global-wheat-detection'
DIR_TRAIN = f'{DIR_INPUT}/train'
DIR_TEST = f'{DIR_INPUT}/test'
Read the data
Read the contents of train.csv:
train_df = pd.read_csv(f'{DIR_INPUT}/train.csv')
train_df.head()
- image_id - the unique image ID
- width, height - the width and height of the images
- bbox - a bounding box, formatted as a Python-style list of [xmin, ymin, width, height]
- etc.
Replace the bbox column with separate x, y, w and h columns:
train_df[['x','y','w','h']] = 0
train_df[['x','y','w','h']] = np.stack(train_df['bbox'].apply(lambda x: np.fromstring(x[1:-1], sep=','))).astype(float)
train_df.drop(columns=['bbox'], inplace=True)
train_df.head()
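The parsing step above strips the brackets by slicing and splits on commas. As a hedged alternative, assuming every bbox value is a JSON-compatible list string such as '[834.0, 222.0, 56.0, 36.0]', the sketch below parses it with the standard library and checks that exactly four values are present. The name parse_bbox is a hypothetical helper, not part of the original notebook.

```python
import json


def parse_bbox(text):
    """Parse a '[xmin, ymin, width, height]' string into four floats.

    Assumes the string is valid JSON; raises ValueError otherwise
    or when the list does not hold exactly four values.
    """
    values = json.loads(text)
    if len(values) != 4:
        raise ValueError(f'expected 4 values, got {len(values)}: {text!r}')
    x, y, w, h = (float(v) for v in values)
    return x, y, w, h


# Example with a made-up box string in the documented format.
box = parse_bbox('[834.0, 222.0, 56.0, 36.0]')
print(box)
```

If the data matches this assumption, the helper could be applied per row with train_df['bbox'].apply(parse_bbox). The explicit length check surfaces malformed rows early instead of letting them produce misaligned columns.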
Analyze the data
Shape of the training data:
train_df.shape
(147793, 8)
Number of unique image_id values:
train_df['image_id'].nunique()
3373
Number of images in the train directory:
len(os.listdir(DIR_TRAIN))
3423
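Since 3422 - 3373 = 49, there are 49 images without annotations. (The listing above reports 3423 entries, one more than the 3422 used in this subtraction.)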
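To see which images lack annotations rather than only how many, one option is a set difference between the file names in the train directory and the image_id values in train.csv. The sketch below is a minimal illustration on toy data; it assumes each image is stored as <image_id>.jpg, and unannotated_ids is a hypothetical helper name.

```python
import os


def unannotated_ids(filenames, annotated_ids):
    """Return sorted image IDs that have a .jpg file but no annotation row.

    Assumes files are named '<image_id>.jpg'; other entries are ignored.
    """
    file_ids = {
        os.path.splitext(name)[0]
        for name in filenames
        if name.lower().endswith('.jpg')
    }
    return sorted(file_ids - set(annotated_ids))


# Toy data standing in for os.listdir(DIR_TRAIN) and train_df['image_id'].
files = ['a1.jpg', 'b2.jpg', 'c3.jpg', 'notes.txt']
annotated = ['a1', 'a1', 'c3']
missing = unannotated_ids(files, annotated)
print(missing)
```

Filtering on the file extension also keeps stray non-image entries in the directory from being counted as images.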
Image sizes in the training data:
train_df['width'].unique(), train_df['height'].unique()
(array([1024]), array([1024]))
All images are 1024x1024.
Distribution of the number of annotations per image:
counts = train_df['image_id'].value_counts()
print(f'number of boxes, range [{min(counts)}, {max(counts)}]')
sns.displot(counts, kde=False)
plt.xlabel('boxes')
plt.ylabel('images')
plt.title('boxes vs. images')
plt.show()
number of boxes, range [1, 116]
A single image has at most 116 annotations.
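Note that value_counts only sees image IDs that appear in train.csv, so the minimum of 1 excludes images with no boxes at all. As a rough sketch on toy data, the counts can be extended with zeros for every known image ID; boxes_per_image and the sample IDs are hypothetical and only illustrate the idea.

```python
from collections import Counter


def boxes_per_image(annotation_ids, all_ids):
    """Count boxes per image, reporting 0 for images with no annotation rows."""
    counts = Counter(annotation_ids)
    return {image_id: counts.get(image_id, 0) for image_id in all_ids}


# One entry per annotation row, plus the full list of image IDs on disk.
per_image = boxes_per_image(['a1', 'a1', 'c3'], ['a1', 'b2', 'c3'])
print(min(per_image.values()), max(per_image.values()))
```

Whether to keep images with zero boxes in the training set is a separate design decision; this only makes them visible in the statistics.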
Distribution of the annotation coordinates, widths and heights:
train_df['cx'] = train_df['x'] + train_df['w'] / 2
train_df['cy'] = train_df['y'] + train_df['h'] / 2
ax = plt.subplots(1, 4, figsize=(16, 4), tight_layout=True)[1].ravel()
ax[0].set_title('x vs. y')
ax[0].set_xlim(0, 1024)
ax[0].set_ylim(0, 1024)
ax[1].set_title('cx vs. cy')
ax[1].set_xlim(0, 1024)
ax[1].set_ylim(0, 1024)
ax[2].set_title('w vs. h')
ax[3].set_title('area size')
sns.histplot(data=train_df, x='x', y='y', ax=ax[0], bins=50, pmax=0.9)
sns.histplot(data=train_df, x='cx', y='cy', ax=ax[1], bins=50, pmax=0.9)
sns.histplot(data=train_df, x='w', y='h', ax=ax[2], bins=50, pmax=0.9)
sns.histplot(train_df['w'] * train_df['h'], ax=ax[3], bins=50, kde=False)
plt.show()
Split the dataset into a training set and a validation set at a ratio of 8:2:
image_ids = train_df['image_id'].unique()
split_len = round(len(image_ids)*0.8)
train_ids = image_ids[:split_len]
valid_ids = image_ids[split_len:]
train = train_df[train_df['image_id'].isin(train_ids)]
valid = train_df[train_df['image_id'].isin(valid_ids)]
train.shape, valid.shape
((122577, 10), (25216, 10))
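The split above takes the first 80% of image IDs in the order they appear in train.csv. If that order is systematic in some way, the validation set may not be representative. A hedged alternative is to shuffle the IDs with a fixed seed before splitting; split_ids, train_fraction and seed are hypothetical names introduced for this sketch.

```python
import random


def split_ids(image_ids, train_fraction=0.8, seed=42):
    """Shuffle image IDs reproducibly, then split them into train/valid lists."""
    ids = list(image_ids)
    random.Random(seed).shuffle(ids)  # local RNG, global random state untouched
    split_len = round(len(ids) * train_fraction)
    return ids[:split_len], ids[split_len:]


# Toy IDs standing in for train_df['image_id'].unique().
demo_ids = [f'img{i}' for i in range(10)]
demo_train, demo_valid = split_ids(demo_ids)
print(len(demo_train), len(demo_valid))
```

Splitting by image ID rather than by row, as the original code also does, keeps all boxes of one image in the same subset.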
Preview the data
Define a few helper functions:
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
for i