# -*- coding: utf-8 -*-
"""
@Time   : 2021/9/23 9:29
@Author : smile 笑
@File   : datasets_text.py
@desc   : VQA dataset, image augmentation, and dataloader utilities.
"""
import pickle
from torch.utils.data import DataLoader, Dataset
from word_sequence import Word2Sequence, SaveWord2Vec
import torchvision.transforms as tfs
import os
from PIL import Image
import json
import torch
import numpy as np
import toml
import argparse
from main import train_configure, test_configure, Sort2Id
from word_sequence import sentence_to_word
from graph.entity_id import EntityId
from graph.graph_extraction import GraphInformationExtraction
from graph.entity_id import EntityId
def train_aug_img(img, configure):
    """Apply the training-time augmentation pipeline to a PIL image.

    Args:
        img: a PIL image (callers convert to RGB before passing it in).
        configure: parsed TOML config dict; all augmentation
            hyper-parameters are read from its ["image"] section.

    Returns:
        A normalized image tensor (ToTensor followed by Normalize).
    """
    img_cfg = configure["image"]  # hoist the repeated section lookup
    aug = tfs.Compose([
        # Random crop, resized back to the configured square size.
        tfs.RandomResizedCrop(
            img_cfg["img_height"],
            scale=(img_cfg["resized_crop_left"], img_cfg["resized_crop_right"]),
        ),
        tfs.RandomApply(
            [tfs.GaussianBlur(kernel_size=img_cfg["b_size"], sigma=img_cfg["blur"])],
            p=img_cfg["blur_p"],
        ),
        tfs.RandomGrayscale(p=img_cfg["grayscale"]),
        tfs.RandomApply(
            [tfs.ColorJitter(img_cfg["brightness"], img_cfg["contrast"],
                             img_cfg["saturation"], img_cfg["hue"])],
            p=img_cfg["apply_p"],
        ),
        tfs.RandomRotation(img_cfg["img_rotation"]),
        tfs.RandomHorizontalFlip(img_cfg["img_flip"]),
        tfs.ToTensor(),
        tfs.Normalize(img_cfg["img_mean"], img_cfg["img_std"]),
    ])
    return aug(img)
def test_aug_img(img, configure):
    """Apply the deterministic evaluation-time transform to a PIL image.

    Unlike ``train_aug_img`` this only resizes to the configured
    (height, width), converts to a tensor, and normalizes — no random
    augmentation, so results are reproducible at test time.

    Args:
        img: a PIL image (callers convert to RGB before passing it in).
        configure: parsed TOML config dict with an ["image"] section.

    Returns:
        A normalized image tensor.
    """
    img_cfg = configure["image"]  # hoist the repeated section lookup
    aug = tfs.Compose([
        tfs.Resize([img_cfg["img_height"], img_cfg["img_width"]]),
        tfs.ToTensor(),
        tfs.Normalize(img_cfg["img_mean"], img_cfg["img_std"]),
    ])
    return aug(img)
class VQADataset(Dataset):
    """VQA dataset yielding one sample per question/answer pair.

    Each item is the tuple
    ``(image, qus_id, ans_id, location_id, ans_type_id, graph_id_seq)``.
    Language choice ("en" / "zh"), vocabulary pickles and max sequence
    length all come from the TOML ``configure`` dict; for English with
    ``graph_flag`` enabled, a knowledge-graph entity extractor is also
    attached.
    """

    def __init__(self, configure):
        # NOTE: the original source had ``def init`` (mangled paste) — the
        # constructor must be the ``__init__`` dunder to run at all.
        self.configure = configure
        self.xm_path = configure["dataset"]["dataset_xm_path"]
        self.run_mode = configure["run_mode"]      # e.g. "train" / "train_test" / "test"
        self.graph_flag = configure["graph_flag"]  # whether to use the knowledge graph

        with open(configure["dataset"]["dataset_path"], encoding="utf-8") as f:
            self.queries = json.load(f)

        if configure["lang_choice"] == "en":
            # English subset (per original comment: 4919 train / 1061 test).
            self.queries = [query for query in self.queries if query["q_lang"] == "en"]
            with open(configure["config"]["en_qus_ws_path"], "rb") as f:
                self.qus_ws = pickle.load(f)
            with open(configure["config"]["en_ans_ws_path"], "rb") as f:
                self.ans_ws = pickle.load(f)
            self.max_seq_len = self.configure["config"]["en_qus_seq_len"]
            if self.graph_flag:
                # Equivalent to the original three sequential ifs: the two
                # training modes build the extractor with True, test mode
                # with False; any other run_mode leaves it unset.
                if self.run_mode in ("train", "train_test"):
                    self.graph_extraction = GraphInformationExtraction(
                        True, configure["graph"]["graph_seq_len"])
                elif self.run_mode == "test":
                    self.graph_extraction = GraphInformationExtraction(
                        False, configure["graph"]["graph_seq_len"])
        else:
            # Chinese subset (per original comment: 4916 train / 1033 test).
            self.queries = [query for query in self.queries if query["q_lang"] == "zh"]
            with open(configure["config"]["zh_qus_ws_path"], "rb") as f:
                self.qus_ws = pickle.load(f)
            with open(configure["config"]["zh_ans_ws_path"], "rb") as f:
                self.ans_ws = pickle.load(f)
            self.max_seq_len = self.configure["config"]["zh_qus_seq_len"]

        with open(configure["config"]["sort_ws_path"], "rb") as f:
            self.sort_ws = pickle.load(f)

    def __getitem__(self, idx):
        """Load and encode the sample at ``idx``.

        Returns:
            image:        augmented (train) or deterministic (other modes) tensor.
            qus_id:       question token ids, padded/truncated to max_seq_len.
            ans_id:       answer id (single-element transform of the answer).
            location_id:  body-location id from the sort vocabulary.
            ans_type_id:  configured id for OPEN vs. CLOSE answers.
            graph_id_seq: graph entity ids (placeholder [0] unless en+graph).
        """
        query = self.queries[idx]
        img_path = os.path.join(self.xm_path + str(query["img_id"]), "source.jpg")

        graph_id_seq = torch.tensor([0])  # placeholder when the graph is unused
        if self.configure["lang_choice"] == "en":
            question = sentence_to_word(query["question"], True)
            answer = sentence_to_word(query["answer"], False)
            if self.graph_flag:
                graph_id_seq = self.graph_extraction(query["question"])
        else:
            question = [i for i in query["question"]]  # Chinese: split per character
            answer = query["answer"]

        if self.run_mode == "train":
            image = train_aug_img(Image.open(img_path).convert("RGB"), self.configure)
        else:
            image = test_aug_img(Image.open(img_path).convert("RGB"), self.configure)

        location = query["location"]
        location_id = self.sort_ws.sort_to_id(location)

        # Map the answer type onto the two configured class ids.
        if query["answer_type"] == "OPEN":
            ans_type_id = self.configure["config"]["answer_open"]
        else:
            ans_type_id = self.configure["config"]["answer_close"]

        qus_id = self.qus_ws.transform(question, max_len=self.max_seq_len)
        ans_id = self.ans_ws.transform([answer])

        return image, qus_id, ans_id, location_id, ans_type_id, graph_id_seq

    def __len__(self):
        """Number of (language-filtered) question/answer samples."""
        return len(self.queries)
def collate_fn(batch):
    """Collate a list of dataset samples into batched CUDA tensors.

    Args:
        batch: list of ``(image, qus_id, ans_id, location_id, ans_type_id,
            graph_id_seq)`` tuples produced by ``VQADataset.__getitem__``.

    Returns:
        The six fields stacked/tensorized and moved to the first configured
        GPU.

    NOTE(review): this always uses ``train_configure`` for the device, even
    if the dataloader is built from a test config — confirm that is intended.
    """
    image, qus_id, ans_id, location_id, ans_type_id, graph_id_seq = list(zip(*batch))

    device = train_configure["config"]["device_ids"][0]  # hoist repeated lookup
    image = torch.stack(image).cuda(device)
    qus_id = torch.tensor(qus_id, dtype=torch.int64).cuda(device)
    ans_id = torch.tensor(ans_id, dtype=torch.int64).cuda(device)
    location_id = torch.tensor(location_id, dtype=torch.int64).cuda(device)
    ans_type_id = torch.tensor(ans_type_id, dtype=torch.int64).cuda(device)
    graph_id_seq = torch.tensor(graph_id_seq).cuda(device)

    return image, qus_id, ans_id, location_id, ans_type_id, graph_id_seq
def get_dataloader(configure):
    """Build a DataLoader over ``VQADataset`` for the given config.

    Batch size and shuffle flag come from ``configure["config"]``; batching
    onto the GPU is handled by this module's ``collate_fn``.
    """
    db = VQADataset(configure)
    dl = DataLoader(
        db,
        batch_size=configure["config"]["batch_size"],
        shuffle=configure["config"]["shuffle"],
        collate_fn=collate_fn,
    )
    return dl
if __name__ == '__main__':
    # Smoke test: load a config, pull one batch, and print the shapes.
    parser = argparse.ArgumentParser()
    # Original had an en-dash ("–config_path"); argparse needs "--".
    parser.add_argument("--config_path", default="./config/train.toml")
    args = parser.parse_args()
    config = toml.load(args.config_path)

    for idx, (img, qus, ans, location, ans_type, entity_id_seq) in enumerate(get_dataloader(config)):
        print(img.shape)
        print(qus.shape)
        print(ans.shape)
        print(location.shape)
        print(ans_type.shape)
        print(entity_id_seq)
        break  # one batch is enough for the smoke test