# slake数据集数据预处理 (SLAKE dataset preprocessing)

# -*- coding: utf-8 -*-

"""
@Time : 2021/9/23 9:29
@Author : smile 笑
@File : datasets_text.py
@desc :
"""

import pickle
from torch.utils.data import DataLoader, Dataset
from word_sequence import Word2Sequence, SaveWord2Vec
import torchvision.transforms as tfs
import os
from PIL import Image
import json
import torch
import numpy as np
import toml
import argparse
from main import train_configure, test_configure, Sort2Id
from word_sequence import sentence_to_word
from graph.entity_id import EntityId
from graph.graph_extraction import GraphInformationExtraction
from graph.entity_id import EntityId

def train_aug_img(img, configure):
    """Apply the training-time augmentation pipeline to a single image.

    Args:
        img: a PIL image (expected RGB; callers convert before passing).
        configure: nested config dict; all augmentation hyper-parameters
            are read from configure["image"].

    Returns:
        The augmented image as a normalized torch tensor (C, H, W).
    """
    img_cfg = configure["image"]  # hoist the repeated sub-dict lookup
    aug = tfs.Compose([
        # random crop-and-resize to the target height, scale range from config
        tfs.RandomResizedCrop(img_cfg["img_height"],
                              scale=(img_cfg["resized_crop_left"], img_cfg["resized_crop_right"])),
        # gaussian blur applied with probability blur_p
        tfs.RandomApply([tfs.GaussianBlur(kernel_size=img_cfg["b_size"], sigma=img_cfg["blur"])],
                        p=img_cfg["blur_p"]),
        tfs.RandomGrayscale(p=img_cfg["grayscale"]),
        # color jitter applied with probability apply_p
        tfs.RandomApply(
            [tfs.ColorJitter(img_cfg["brightness"], img_cfg["contrast"],
                             img_cfg["saturation"], img_cfg["hue"])],
            p=img_cfg["apply_p"]
        ),
        tfs.RandomRotation(img_cfg["img_rotation"]),
        tfs.RandomHorizontalFlip(img_cfg["img_flip"]),
        tfs.ToTensor(),
        tfs.Normalize(img_cfg["img_mean"], img_cfg["img_std"])
    ])

    return aug(img)

def test_aug_img(img, configure):
    """Apply the deterministic evaluation-time transform to a single image.

    Args:
        img: a PIL image (expected RGB; callers convert before passing).
        configure: nested config dict; image size and normalization stats
            are read from configure["image"].

    Returns:
        The resized, normalized image as a torch tensor (C, H, W).
    """
    img_cfg = configure["image"]  # hoist the repeated sub-dict lookup
    aug = tfs.Compose([
        tfs.Resize([img_cfg["img_height"], img_cfg["img_width"]]),
        tfs.ToTensor(),
        tfs.Normalize(img_cfg["img_mean"], img_cfg["img_std"])
    ])

    return aug(img)

class VQADataset(Dataset):
    """SLAKE visual-question-answering dataset.

    Loads (image, question, answer) triples from a JSON query file, filters
    them by language ("en" or "zh"), encodes questions/answers with the
    pickled Word2Sequence vocabularies, and optionally extracts a knowledge-
    graph entity-id sequence for English questions.
    """

    def __init__(self, configure):
        """Build the dataset from a nested TOML-style config dict.

        Args:
            configure: config dict with "dataset", "config", "graph",
                "image" sections plus top-level "run_mode", "graph_flag"
                and "lang_choice" keys.
        """
        self.configure = configure
        self.xm_path = configure["dataset"]["dataset_xm_path"]
        self.run_mode = configure["run_mode"]  # 得到运行方式 (run mode: train / train_test / test)
        self.graph_flag = configure["graph_flag"]  # 得到是否使用知识图谱 (whether to use the knowledge graph)

        self.queries = json.load(open(configure["dataset"]["dataset_path"], encoding="utf-8"))

        if configure["lang_choice"] == "en":
            self.queries = [query for query in self.queries if query["q_lang"] == "en"]  # 4919、1061
            self.qus_ws = pickle.load(open(configure["config"]["en_qus_ws_path"], "rb"))
            self.ans_ws = pickle.load(open(configure["config"]["en_ans_ws_path"], "rb"))
            self.max_seq_len = self.configure["config"]["en_qus_seq_len"]

            if self.graph_flag:
                # train / train_test share the training-side extractor;
                # test uses the evaluation-side one.
                if self.run_mode in ("train", "train_test"):
                    self.graph_extraction = GraphInformationExtraction(True, configure["graph"]["graph_seq_len"])
                elif self.run_mode == "test":
                    self.graph_extraction = GraphInformationExtraction(False, configure["graph"]["graph_seq_len"])

        else:
            self.queries = [query for query in self.queries if query["q_lang"] == "zh"]  # train:4916、test:1033
            self.qus_ws = pickle.load(open(configure["config"]["zh_qus_ws_path"], "rb"))
            self.ans_ws = pickle.load(open(configure["config"]["zh_ans_ws_path"], "rb"))
            self.max_seq_len = self.configure["config"]["zh_qus_seq_len"]

        self.sort_ws = pickle.load(open(configure["config"]["sort_ws_path"], "rb"))

    def __getitem__(self, idx):
        """Return one encoded sample.

        Returns:
            (image tensor, question id seq, answer id, location id,
             answer-type id, graph entity id seq) — graph seq is a
             placeholder tensor([0]) when the graph is unused.
        """
        query = self.queries[idx]  # 随机抽取一个

        img_path = os.path.join(self.xm_path + str(query["img_id"]), "source.jpg")

        graph_id_seq = torch.tensor([0])  # default when graph extraction is off

        if self.configure["lang_choice"] == "en":
            question = sentence_to_word(query["question"], True)
            answer = sentence_to_word(query["answer"], False)
            if self.graph_flag:
                graph_id_seq = self.graph_extraction(query["question"])

        else:
            question = [i for i in query["question"]]  # 中文的分割方式 (split Chinese into characters)
            answer = query["answer"]

        # training gets random augmentation; everything else gets the
        # deterministic resize pipeline
        if self.run_mode == "train":
            image = train_aug_img(Image.open(img_path).convert("RGB"), self.configure)
        else:
            image = test_aug_img(Image.open(img_path).convert("RGB"), self.configure)

        location = query["location"]
        location_id = self.sort_ws.sort_to_id(location)

        answer_type = query["answer_type"]
        if answer_type == "OPEN":
            ans_type_id = self.configure["config"]["answer_open"]
        else:
            ans_type_id = self.configure["config"]["answer_close"]

        qus_id = self.qus_ws.transform(question, max_len=self.max_seq_len)
        ans_id = self.ans_ws.transform([answer])

        return image, qus_id, ans_id, location_id, ans_type_id, graph_id_seq

    def __len__(self):
        return len(self.queries)

def collate_fn(batch):
    """Collate a list of VQADataset samples into batched CUDA tensors.

    Args:
        batch: list of (image, qus_id, ans_id, location_id, ans_type_id,
            graph_id_seq) tuples produced by VQADataset.__getitem__.

    Returns:
        The six fields as batched int64/float tensors moved to the first
        device listed in train_configure["config"]["device_ids"].

    NOTE(review): this always reads train_configure even when the loader
    is used for testing — presumably both configs share device_ids; verify.
    """
    device = train_configure["config"]["device_ids"][0]  # hoist repeated lookup
    image, qus_id, ans_id, location_id, ans_type_id, graph_id_seq = list(zip(*batch))
    image = torch.stack(image).cuda(device)
    qus_id = torch.tensor(qus_id, dtype=torch.int64).cuda(device)
    ans_id = torch.tensor(ans_id, dtype=torch.int64).cuda(device)
    location_id = torch.tensor(location_id, dtype=torch.int64).cuda(device)
    ans_type_id = torch.tensor(ans_type_id, dtype=torch.int64).cuda(device)

    graph_id_seq = torch.tensor(graph_id_seq).cuda(device)

    return image, qus_id, ans_id, location_id, ans_type_id, graph_id_seq

def get_dataloader(configure):
    """Build a DataLoader over VQADataset using batch size / shuffle from config.

    Args:
        configure: nested config dict (see VQADataset).

    Returns:
        A torch DataLoader yielding batches collated by collate_fn.
    """
    db = VQADataset(configure)
    dl = DataLoader(db, batch_size=configure["config"]["batch_size"],
                    shuffle=configure["config"]["shuffle"], collate_fn=collate_fn)

    return dl

if __name__ == '__main__':
    # Smoke test: load the config, build one dataloader, print the shapes
    # of the first batch, then stop.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", default="./config/train.toml")
    args = parser.parse_args()
    config = toml.load(args.config_path)

    for idx, (img, qus, ans, location, ans_type, entity_id_seq) in enumerate(get_dataloader(config)):
        print(img.shape)
        print(qus.shape)
        print(ans.shape)
        print(location.shape)
        print(ans_type.shape)
        print(entity_id_seq)
        break
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值