Flickr8k数据集处理(草履虫也能看得懂!)

1:下载数据集

Flickr8k:该数据集选取了图片社交网站flickr中总计8000幅关于人或动物某种行为的图像。每幅图像都对应5个人工标记的句子描述。数据集划分一般采用Karpathy提供的方法:6000幅图像和其对应的句子描述组成训练集,1000幅图像和描述为验证集,剩余1000幅图像和描述为测试集。

数据集下载地址:

图像:https://www.kaggle.com/datasets/adityajn105/flickr8k

Karpathy方法json文件:http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip

确保 dataset_flickr8k.json 文件和 image 文件夹在同一目录下。

注:可能会出现第二个链接(json文件)无法打开的问题。这个链接是个下载连接,而不是个网页的链接。点击这个链接后应该会直接跳转下载,如果电脑上有迅雷什么的应该会直接弹出来。


2:整理数据集

数据集下载完成后,需要对其进行处理,以适合之后构造的 Pytorch 数据集类进行读取。对于文本描述,首先构建词典,然后根据词典将文本描述转化为向量。对于图像,这里仅记录文件路径。如果机器的内存和硬盘空间足够,这里也可以将图片读取并处理为三维数组,这样在模型训练和测试阶段就不需要再读取图片。

dataset_flickr8k.json文件分析:

下面是整理数据集的函数代码。

导包:

import json
import os 
import random
from collections import Counter, defaultdict
from matplotlib import pyplot as plt
from PIL import Image

函数:

def create_dataset(dataset = 'flickr8k', captions_per_image = 5, min_word_count = 5, max_len = 30):
    """
    参数:
        dataset:数据集名称
        captions_per_image:每张图片对应的文本描述数
        min_word_count:仅考虑在数据集中(除测试集外)出现5次的词
        max_len:文本描述包含的最大单词数。如果文本描述超过该值则截断
    输出:
        一个词典文件:vocab.json
        三个数据集文件:train_data.json val_data.json test_data.json
    """
    
    karpathy_json_path = "./dataset_flickr8k.json"  # 读取json文件
    image_folder = "./images/"                      # 图片文件夹
    output_folder = "./data/%s" % dataset           # 保存处理结果的文件夹

    # 读取数据集文本描述的json文件
    with open(file=karpathy_json_path, mode="r") as j:
        data = json.load(fp=j)
    

    image_paths = defaultdict(list)                 # collections.defaultdict() 参考:https://zhuanlan.zhihu.com/p/345741967 ; https://blog.csdn.net/sinat_38682860/article/details/112878842
    image_captions = defaultdict(list)
    vocab = Counter()                               # collections.Counter() 主要功能:可以支持方便、快速的计数,将元素数量统计,然后计数并返回一个字典,键为元素,值为元素个数。 参考:https://blog.csdn.net/chl183/article/details/106956807

    for img in data["images"]:                      # 读取每张图片
        split = img["split"]                        # split:该图片文本描述的编号 len(spilt)==5
        captions = []                               
        for c in img["sentences"]:                  # 读取图片的文本描述

            # 更新词频,测试集在训练过程中未见到数据集
            if split != "test":                     # 只读取train/val
                vocab.update(c['tokens'])           # 更新词表 这里的c['tokens']是一个列表,在这里对这列表中每个元素,也就是每个词使其在词表中的出现个数加一 参考:https://blog.csdn.net/ljr_123/article/details/106209564 ; https://blog.csdn.net/ljr_123/article/details/106209564
            
            # 不统计超过最大长度限制的词
            if len(c["tokens"]) <= max_len:
                captions.append(c["tokens"])        # 如果每个句子的单词数都不大与max_len,则len(captions)+=5

        if len(captions) == 0:                      # 万一有超过的也得往下循环
            continue
        
        path =os.path.join(image_folder, img['filename'])    # 读取图片路径:"./images/img['filename']" 这里img['filename']为图片名字 os.path.join()函数用于路径拼接文件路径,可以传入多个路径 参考:https://blog.csdn.net/swan777/article/details/89040802

        image_paths[split].append(path)             # 保存每张图片路径
        image_captions[split].append(captions)      # 保存每张图片对应描述文本
    
    """
    执行完以上步骤后得到了:vocab, image_captions, image_paths

    vocab 为一个字典结构,key为各个出现的词; value为这个词出现的个数
    image_captions 为一个字典结构,key为"train","val"; val为列表,表中元素为一个个文本描述的列表
    image_paths 为一个字典结构,key为"train","val"; val为列表,表中元素为图片路径的字符串
    
    可运行以下代码验证:
    print(vocab)
    print(image_paths["train"][1])
    print(image_captions["train"][1])
    """
    
    # 创造词典,增加占位符<pad>,未登录词标识符<unk>,句子首尾标识符<start>和<end>
    words = [w for w in vocab.keys() if vocab[w]> min_word_count]
    vocab = {k : v + 1 for v, k in enumerate(words)}

    vocab['<pad>']=0
    vocab['<unk>']=len(vocab)
    vocab['<start>']=len(vocab)
    vocab['<end>']=len(vocab)
    
    # 储存词典
    with open(os.path.join(output_folder, 'vocab.json'),"w") as fw:
        json.dump(vocab,fw)
    

    # 整理数据集
    for split in image_paths:                       # 只会循环三次 split = "train" 、 split = "val" 和 split = "test"
        
        imgpaths = image_paths[split]               # type(imgpaths)=list
        imcaps = image_captions[split]              # type(imcaps)=list
        enc_captions = []
    
        for i, path in enumerate(imgpaths):

            # 合法性检测,检查图像时候可以被解析
            img = Image.open(path)                  # 参考:https://blog.csdn.net/weixin_43723625/article/details/108158375
            
            # 如果图像对应的描述数量不足,则补足
            if len(imcaps[i]) < captions_per_image:
                filled_num = captions_per_image - len(imcaps[i])
                captions = imcaps[i]+ [random.choice(imcaps[i]) for _ in range(0, filled_num)]
            else:
                captions = random.sample(imcaps[i],k=captions_per_image)        # 打乱文本描述 参考:https://blog.csdn.net/qq_37281522/article/details/85032470
            
            assert len(captions)==captions_per_image

            for j,c in enumerate(captions):
                # 对文本描述进行编码
                enc_c = [vocab['<start>']] + [vocab.get(word, vocab['<unk>']) for word in c] + [vocab["<end>"]]
                enc_captions.append(enc_c)
    
        assert len(imgpaths)* captions_per_image == len(enc_captions)

        data = {"IMAGES" : imgpaths,
                "CAPTIONS" : enc_captions}
        
        # 储存训练集,验证集,测试集
        with open(os.path.join(output_folder,split+"_data.json"),'w') as fw:
            json.dump(data, fw)

运行函数完成处理:

create_dataset()

看看效果:

with open('./data/flickr8k/vocab.json','r') as f:
    vocab =json.load(f)
vocab_idx2word = {idx : word for word, idx in vocab.items()}

with open('./data/flickr8k/test_data.json','r') as f:
    data = json.load(f)


content_img = Image.open(data['IMAGES'][300])
plt.imshow(content_img)

print(len(data))
print(len(data['IMAGES']))
print(len(data["CAPTIONS"]))

for i in range(5):
    word_indeces = data['CAPTIONS'][300*5+i]
    print(''.join([vocab_idx2word[idx] for idx in word_indeces]))

运行结果:

2
1000
5000
<start>peopleinorangerobeslineupbehindamanwearingsunglasses<end>
<start>three<unk>walkonthestreet<end>
<start>agroupofmenwearingyellow<unk>walkinaline<end>
<start>menwalkinlinecarryingthings<end>
<start>threemeninorangerobesholdingmetal<unk><end>


3:定义数据集类

在准备好的数据集的基础上,需要进一步定义 Pytorch Dataset 类,以使用 Pytorch Dataloder 类按批次产生数据。Pytorch 中仅预先定义了图像、文本和语音的单模态任务中常见的数据集,因此我们还是要对 Flickr8k 数据集进行处理。

在 Pytorch 中定义数据集十分简单,仅继承 torch.utils.data.Dataset 类,并实现 __getitem__ 和__len__ 两个函数即可。

导包:

from argparse import Namespace
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as transforms

class类如下

class ImageTextDataset(Dataset):
    """
    Pytorch 数据类,用于 P有torch Dataloader 来按批次产生数据
    """
    def __init__(self, dataset_path, vocab_path, split, captions_per_image = 5, max_len = 30, transform = None):
        """
        参数:
            dataset_path: json 格式数据文件路径
            vocab_path: json 格式字典文件路径
            split: "tarin", "val", "test"
            captions_per_image: 每张图片对应的文本描述数
            max_len: 文本描述最大单词量
            transform: 图像预处理方法
        """
        self.split = split
        assert self.split in {"tarin", "val", "test"}       # assert的应用 参考:https://blog.csdn.net/TeFuirnever/article/details/88883859
        self.cpi = captions_per_image
        self.max_len = max_len

        # 载入图像
        with open(dataset_path,"r") as f:
            self.data = json.load(f)
        
        # 载入词典
        with open(vocab_path,"r") as f:
            self.vocab = json.load(f)
        
        # 图像预处理流程
        self.transform = transform

        # 数据量
        self.dataset_size = len(self.data["CAPTIONS"])
    
    def __getitem__(self, i):

        # 第 i 个样本描述对应第 (i // captions_per_image) 张图片
        img = Image.open(self.data['IMAGES'][i // self.cpi]).convert("RGB")     # 参考:https://blog.csdn.net/nienelong3319/article/details/105458742
        
        # 如歌有图像预处理流程,进行图像预处理
        if self.transform is not None:
            img = self.transform(img)
        
        caplen = len(self.data["CAPTIONS"][i])
        pad_caps = [self.vocab['<pad>']] * (self.max_len + 2 - caplen)
        caption = torch.LongTensor(self.data["CAPTIONS"][i] + pad_caps)         # 类型转换 参考:https://blog.csdn.net/qq_45138078/article/details/131557441

        return img, caption, caplen

    def __len__(self):
        return self.dataset_size

如果想提前对数据集进行处理可以设置 transform = transforms.XXX

具体可以参考:https://blog.csdn.net/qq_37555071/article/details/107532319

在这里就设为None,在下一步中再进行处理


4:批量读取数据

利用刚才构造的数据集类,借助 Dataloder 类构建能够批量产生训练、验证、测试数据对象。

函数:

def mktrainval(data_dir, vocab_path, batch_size, workers = 4):
    train_tx = transforms.Compose([
        transforms.Resize(256),                                                     # 缩放
        transforms.RandomCrop(224),                                                 # 随机裁剪
        transforms.ToTensor(),                                                      # 用于对载入的图片数据进行类型转换,将图片数据转换成Tensor数据类型的变量
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 标准化,这里的均值和方差为在ImageNet数据集上抽样计算出来的
    ])

    val_tx = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),                                                 # 中心裁剪
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    train_set = ImageTextDataset(dataset_path = os.path.join(data_dir,"train_data.json"), vocab_path = vocab_path, split = "train", transform = train_tx)
    vaild_set = ImageTextDataset(dataset_path = os.path.join(data_dir,"val_data.json"), vocab_path = vocab_path, split = "val", transform = val_tx)
    test_set = ImageTextDataset(dataset_path = os.path.join(data_dir,"test_data.json"), vocab_path = vocab_path, split = "test", transform = val_tx)

    train_loder = data.DataLoader(              
        dataset = train_set, batch_size = batch_size, shuffer = True,
        num_workers = workers, pin_memory = True
    )                                   # 参考:https://blog.csdn.net/rocketeerLi/article/details/90523649 ; https://blog.csdn.net/zfhsfdhdfajhsr/article/details/116836851

    val_loder = data.DataLoader(              
        dataset = vaild_set, batch_size = batch_size, shuffer = False,
        num_workers = workers, pin_memory = True, drop_last=False
    )                                   # 验证集和测试集不需要打乱数据顺序:shuffer = False

    test_loder = data.DataLoader(              
        dataset = test_set, batch_size = batch_size, shuffer = False,
        num_workers = workers, pin_memory = True, drop_last=False
    )                                   

    return train_loder, val_loder, test_loder

至此,数据集的处理就已经完成了。train_loder 为训练集,val_loder 为验证集,test_loder 为测试集。

参考:

        《多模态深度学习技术基础》冯方向 王小捷

评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值