转自:https://segmentfault.com/a/1190000022771088
查看数据
在构建数据集之前,我们先对数据进行一些可视化,对数据有一个大致的了解。
文件路径如下, 将其保存为字典
data_dir = {
'train_data': '/content/data/mchar_train/',
'val_data': '/content/data/mchar_val/',
'test_data': '/content/data/mchar_test_a/',
'train_label': '/content/drive/My Drive/Data/Datawhale-DigitsRecognition/mchar_train.json',
'val_label': '/content/drive/My Drive/Data/Datawhale-DigitsRecognition/mchar_val.json',
'submit_file': '/content/drive/My Drive/Data/Datawhale-DigitsRecognition/mchar_sample_submit_A.csv'
}
查看图片数量
def data_summary():
train_list = glob(data_dir['train_data']+'*.png')
test_list = glob(data_dir['test_data']+'*.png')
val_list = glob(data_dir['val_data']+'*.png')
print('train image counts: %d'%len(train_list))
print('val image counts: %d'%len(val_list))
print('test image counts: %d'%len(test_list))
data_summary()
train image counts: 30000
val image counts: 10000
test image counts: 40000
查看标注文件信息
def look_train_json():
with open(data_dir['train_label'], 'r', encoding='utf-8') as f:
content = f.read()
# loads将字符串转为字典
content = json.loads(content)
print(content['000000.png'])
look_train_json()
{'height': [219, 219], 'label': [1, 9], 'left': [246, 323], 'top': [77, 81], 'width': [81, 96]}
查看结果文件提交格式
def look_submit():
df = pd.read_csv(data_dir['submit_file'], sep=',')
print(df.head(5))
look_submit()
file_name file_code
0 000000.png 0
1 000001.png 0
2 000002.png 0
3 000003.png 0
4 000004.png 0
在图片上查看标注框
def plot_samples():
imgs = glob(data_dir['train_data']+'*.png')
fig, ax = plt.subplots(figsize=(12, 8), ncols=2, nrows=2)
marks = json.loads(open(data_dir['train_label'], 'r').read())
for i in range(4):
img_name = os.path.split(imgs[i])[-1]
mark = marks[img_name]
img = Image.open(imgs[i])
img = np.array(img)
bboxes = np.array(
[mark['left'],
mark['top'],
mark['width'],
mark['height']]
)
ax[i//2, i%2].imshow(img)
for j in range(len(mark['label'])):
# 定义一个矩形
rect = patch.Rectangle(bboxes[:, j][:2], bboxes[:, j][2], bboxes[:, j][3], facecolor='none', edgecolor='r')
ax[i//2, i%2].text(bboxes[:, j][0], bboxes[:, j][1], mark['label'][j])
# 绘制矩形
ax[i//2, i%2].add_patch(rect)
plt.show()
plot_samples()
查看训练图片的长宽分布
def img_size_summary():
sizes = []
for img in glob(data_dir['train_data']+'*.png'):
img = Image.open(img)
sizes.append(img.size)
sizes = np.array(sizes)
plt.figure(figsize=(10, 8))
plt.scatter(sizes[:, 0], sizes[:, 1])
plt.xlabel('Width')
plt.ylabel('Height')
plt.title('image width-height summary')
plt.show()
return np.mean(sizes, axis=0), np.median(sizes, axis=0)
mean, median = img_size_summary()
print('mean: ', mean)
print('median: ', median)
可以看到,训练图片之间的尺寸差异非常大,且基本上都是宽要比高大,宽之间的差异大于高之间的差异。后续确定网络输入大小,可以结合中位数或平均值确定网络输入大小。
查看边界框大小分布
def bbox_summary():
marks = json.loads(open(data_dir['train_label'], 'r').read())
bboxes = []
for img, mark in marks.items():
for i in range(len(mark['label'])):
bboxes.append([mark['left'][i], mark['top'][i], mark['width'][i], mark['height'][i]])
bboxes = np.array(bboxes)
fig, ax = plt.subplots(figsize=(12, 8))
ax.scatter(bboxes[:, 2], bboxes[:, 3])
ax.set_title('bbox width-height summary')
ax.set_xlabel('width')
ax.set_ylabel('height')
plt.show()
bbox_summary()
如果采用目标检测的思路实现字符识别,可以使用Kmeans聚类的方式来对边界框来确定anchor尺寸。
查看不同字符类别的数目
def label_nums_summary():
marks = json.load(open(data_dir['train_label'], 'r'))
dicts = {i: 0 for i in range(10)}
for img, mark in marks.items():
for lb in mark['label']:
dicts[lb] += 1
xticks = list(range(10))
fig, ax = plt.subplots(figsize=(10, 8))
ax.bar(x=list(dicts.keys()), height=list(dicts.values()))
ax.set_xticks(xticks)
plt.show()
return dicts
print(label_nums_summary())
可以看出不同类别之间差别总体差异不大,除了数字1的出现次数较大。没有出现极端不平衡的情况。后期分类可以考虑使用Weighted-CrossEntropy损失。
查看每个图片出现的数字个数
def label_summary():
marks = json.load(open(data_dir['train_label'], 'r'))
dicts = {}
for img, mark in marks.items():
if len(mark['label']) not in dicts:
dicts[len(mark['label'])] = 0
dicts[len(mark['label'])] += 1
dicts = sorted(dicts.items(), key=lambda x: x[0])
for k, v in dicts:
print('%d个数字的图片数目: %d'%(k, v))
label_summary()
1个数字的图片数目: 4636
2个数字的图片数目: 16262
3个数字的图片数目: 7813
4个数字的图片数目: 1280
5个数字的图片数目: 8
6个数字的图片数目: 1
可以看到,只有一个图片包含数字为6个,可能是异常值,可以不予考虑。几乎全部1~4个数字的图片几乎占了训练图片的全部。
构建数据集
这里,我们借鉴Datawhale提供的Baseline, 由于每个图片最多只包含不到6个数字,为了简化,将字符识别当做一个分类问题来处理。
这里自定义数据集,DigitsDataset继承自torch.utils.data.Dataset,数据增强使用自带的torchvison.transforms。这里只进行了常规的增强操作,比如旋转,随机转灰度,随机调整HSV等。
class DigitsDataset(Dataset):
"""
DigitsDataset
Params:
data_dir(string): data directory
label_path(string): label path
aug(bool): wheather do image augmentation, default: True
"""
def __init__(self, data_dir, label_path, size=(64, 128), aug=True):
super(DigitsDataset, self).__init__()
self.imgs = glob(data_dir+'*.png')
self.aug = aug
self.size = size
if label_path == None:
self.labels = None
else:
self.labels = json.load(open(label_path, 'r'))
self.imgs = [(img, self.labels[os.path.split(img)[-1]]) for img in self.imgs if os.path.split(img)[-1] in self.labels]
def __getitem__(self, idx):
if self.labels:
img, label = self.imgs[idx]
else:
img = self.imgs[idx]
label = None
img = Image.open(img)
trans0 = [
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
min_size = self.size[0] if (img.size[1] / self.size[0]) < ((img.size[0] / self.size[1])) else self.size[1]
trans1 = [
transforms.Resize(min_size),
transforms.CenterCrop(self.size)
]
if self.aug:
trans1.extend([
transforms.ColorJitter(0.1, 0.1, 0.1),
transforms.RandomGrayscale(0.1),
transforms.RandomAffine(10,translate=(0.05, 0.1), shear=5)
])
trans1.extend(trans0)
img = transforms.Compose(trans1)(img)
if self.labels:
return img, t.tensor(label['label'][:5] + (5 - len(label['label']))*[10]).long()
else:
return img, self.imgs[idx]
def __len__(self):
return len(self.imgs)
查看一下数据增强的效果
fig, ax = plt.subplots(figsize=(6, 12), nrows=4, ncols=2)
for i in range(8):
img, label = dataset[i]
# 这些需要进行逆标准化
img = img * t.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) + t.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
ax[i//2, i%2].imshow(img.permute(1, 2, 0).numpy())
ax[i//2, i%2].set_xticks([])
ax[i//2, i%2].set_yticks([])
plt.show()
总结
这里主要介绍了数据的准备和数据集的构建,并未使用比较高级复杂的操作,目的是为了搭建一个基础的数据框架,后续可以更加方便的在此基础上增加其他的操作。