DataWhale街景字符编码识别项目-数据准备

转自:https://segmentfault.com/a/1190000022771088

查看数据

在构建数据集之前,我们先对数据进行一些可视化,对数据有一个大致的了解。
文件路径如下, 将其保存为字典

data_dir = {
    'train_data': '/content/data/mchar_train/',
    'val_data': '/content/data/mchar_val/',
    'test_data': '/content/data/mchar_test_a/',
    'train_label': '/content/drive/My Drive/Data/Datawhale-DigitsRecognition/mchar_train.json',
    'val_label': '/content/drive/My Drive/Data/Datawhale-DigitsRecognition/mchar_val.json',
    'submit_file': '/content/drive/My Drive/Data/Datawhale-DigitsRecognition/mchar_sample_submit_A.csv'
}

查看图片数量

def data_summary():
    train_list = glob(data_dir['train_data']+'*.png')
    test_list = glob(data_dir['test_data']+'*.png')
    val_list = glob(data_dir['val_data']+'*.png')
    print('train image counts: %d'%len(train_list))
    print('val image counts: %d'%len(val_list))
    print('test image counts: %d'%len(test_list))

data_summary()
train image counts: 30000
val image counts: 10000
test image counts: 40000

查看标注文件信息

def look_train_json():
    with open(data_dir['train_label'], 'r', encoding='utf-8') as f:
        content = f.read()
    # loads将字符串转为字典
    content = json.loads(content)

    print(content['000000.png'])

look_train_json()
{'height': [219, 219], 'label': [1, 9], 'left': [246, 323], 'top': [77, 81], 'width': [81, 96]}

查看结果文件提交格式

def look_submit():
    df = pd.read_csv(data_dir['submit_file'], sep=',')
    print(df.head(5))

look_submit()
    file_name  file_code
0  000000.png          0
1  000001.png          0
2  000002.png          0
3  000003.png          0
4  000004.png          0

在图片上查看标注框

def plot_samples():
    imgs = glob(data_dir['train_data']+'*.png')
    fig, ax = plt.subplots(figsize=(12, 8), ncols=2, nrows=2)
    marks = json.loads(open(data_dir['train_label'], 'r').read())
    for i in range(4):

        img_name = os.path.split(imgs[i])[-1]
        mark = marks[img_name]
        img = Image.open(imgs[i])
        img = np.array(img)

        bboxes = np.array(
            [mark['left'],
            mark['top'],
            mark['width'],
            mark['height']]
        )
        ax[i//2, i%2].imshow(img)
        for j in range(len(mark['label'])):
        
        # 定义一个矩形
        rect = patch.Rectangle(bboxes[:, j][:2], bboxes[:, j][2], bboxes[:, j][3], facecolor='none', edgecolor='r')
        ax[i//2, i%2].text(bboxes[:, j][0], bboxes[:, j][1], mark['label'][j])
        # 绘制矩形
        ax[i//2, i%2].add_patch(rect)
    plt.show()

plot_samples()

image.png

查看训练图片的长宽分布

def img_size_summary():
    sizes = []

    for img in glob(data_dir['train_data']+'*.png'):
        img = Image.open(img)

        sizes.append(img.size)

    sizes = np.array(sizes)

    plt.figure(figsize=(10, 8))
    plt.scatter(sizes[:, 0], sizes[:, 1])
    plt.xlabel('Width')
    plt.ylabel('Height')

    plt.title('image width-height summary')
    plt.show()
    return np.mean(sizes, axis=0), np.median(sizes, axis=0)

mean, median = img_size_summary()
print('mean: ', mean)
print('median: ', median)

image.png

可以看到,训练图片之间的尺寸差异非常大,且基本上都是宽要比高大,宽之间的差异大于高之间的差异。后续确定网络输入大小,可以结合中位数或平均值确定网络输入大小。

查看边界框大小分布

def bbox_summary():
    marks = json.loads(open(data_dir['train_label'], 'r').read())
    bboxes = []

    for img, mark in marks.items():
        for i in range(len(mark['label'])):
        bboxes.append([mark['left'][i], mark['top'][i], mark['width'][i], mark['height'][i]])

    bboxes = np.array(bboxes)

    fig, ax = plt.subplots(figsize=(12, 8))
    ax.scatter(bboxes[:, 2], bboxes[:, 3])
    ax.set_title('bbox width-height summary')
    ax.set_xlabel('width')
    ax.set_ylabel('height')
    plt.show()

bbox_summary()

image.png

如果采用目标检测的思路实现字符识别,可以使用Kmeans聚类的方式来对边界框来确定anchor尺寸。

查看不同字符类别的数目

def label_nums_summary():
    marks = json.load(open(data_dir['train_label'], 'r'))

    dicts = {i: 0 for i in range(10)}
    for img, mark in marks.items():
        for lb in mark['label']:
        dicts[lb] += 1

    xticks = list(range(10))
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.bar(x=list(dicts.keys()), height=list(dicts.values()))
    ax.set_xticks(xticks)
    plt.show()
    return dicts

print(label_nums_summary())

image.png

可以看出不同类别之间差别总体差异不大,除了数字1的出现次数较大。没有出现极端不平衡的情况。后期分类可以考虑使用Weighted-CrossEntropy损失。

查看每个图片出现的数字个数

def label_summary():
    marks = json.load(open(data_dir['train_label'], 'r'))

    dicts = {}
    for img, mark in marks.items():
        if len(mark['label']) not in dicts:
        dicts[len(mark['label'])] = 0
        dicts[len(mark['label'])] += 1
    dicts = sorted(dicts.items(), key=lambda x: x[0])
    for k, v in dicts:
        print('%d个数字的图片数目: %d'%(k, v))

label_summary()

1个数字的图片数目: 4636
2个数字的图片数目: 16262
3个数字的图片数目: 7813
4个数字的图片数目: 1280
5个数字的图片数目: 8
6个数字的图片数目: 1

可以看到,只有一个图片包含数字为6个,可能是异常值,可以不予考虑。几乎全部1~4个数字的图片几乎占了训练图片的全部。

构建数据集

这里,我们借鉴Datawhale提供的Baseline, 由于每个图片最多只包含不到6个数字,为了简化,将字符识别当做一个分类问题来处理。

这里自定义数据集,DigitsDataset继承自torch.utils.data.Dataset,数据增强使用自带的torchvison.transforms。这里只进行了常规的增强操作,比如旋转,随机转灰度,随机调整HSV等。

class DigitsDataset(Dataset):
    """
    
    DigitsDataset
    
    Params:
        data_dir(string): data directory
    
        label_path(string): label path
    
        aug(bool): wheather do image augmentation, default: True
    """
    def __init__(self, data_dir, label_path, size=(64, 128), aug=True):
        super(DigitsDataset, self).__init__()
        self.imgs = glob(data_dir+'*.png')
    
        self.aug = aug
    
        self.size = size
        if label_path == None:
            self.labels = None
        else:
            self.labels = json.load(open(label_path, 'r'))
            self.imgs = [(img, self.labels[os.path.split(img)[-1]]) for img in self.imgs if os.path.split(img)[-1] in self.labels]
        
    def __getitem__(self, idx):
        if self.labels:
            img, label = self.imgs[idx]
        else:
            img = self.imgs[idx]
            label = None
        
            img = Image.open(img)
        
            trans0 = [                
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ]
            
            min_size = self.size[0] if (img.size[1] / self.size[0]) < ((img.size[0] / self.size[1])) else self.size[1]
            trans1 = [
                transforms.Resize(min_size),    
                transforms.CenterCrop(self.size)
                ]
    
        if self.aug:
            trans1.extend([
                    transforms.ColorJitter(0.1, 0.1, 0.1),
                    transforms.RandomGrayscale(0.1),
                    transforms.RandomAffine(10,translate=(0.05, 0.1), shear=5)
            ])
    
        trans1.extend(trans0)
        
        img = transforms.Compose(trans1)(img)
    
        if self.labels:
            return img, t.tensor(label['label'][:5] + (5 - len(label['label']))*[10]).long()
        else:
            return img, self.imgs[idx]
    
    
    def __len__(self):
        return len(self.imgs)

查看一下数据增强的效果

fig, ax = plt.subplots(figsize=(6, 12), nrows=4, ncols=2)
for i in range(8):
    img, label = dataset[i]
    # 这些需要进行逆标准化
    img = img * t.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) + t.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    ax[i//2, i%2].imshow(img.permute(1, 2, 0).numpy())
    
    ax[i//2, i%2].set_xticks([])
    ax[i//2, i%2].set_yticks([])

plt.show()

image.png

总结

这里主要介绍了数据的准备和数据集的构建,并未使用比较高级复杂的操作,目的是为了搭建一个基础的数据框架,后续可以更加方便的在此基础上增加其他的操作。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值