Paper Overview
Paper: https://arxiv.org/abs/1806.08317
Dataset Splits
Total | train | val | test |
---|---|---|---|
293,008 | 260,480 | 32,528 | 32,528 |
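Categories
The dataset has 48 main categories and 121 subcategories.
The category proportions in the training and test sets are shown below.
Image Statistics
The number of images per main category and per subcategory in the training set is shown below.
Text Descriptions
The distribution of description lengths is shown below.
The distribution of colors extracted from the descriptions is shown below.
Challenges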
- Generating high-resolution images using P-GANs
- Text-to-Image synthesis
Evaluation Methods
- Inception Score
- Human Evaluation (needed because Inception Score does not account for how well the image matches the text)
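To make the first metric concrete: Inception Score is commonly defined as exp(E_x[KL(p(y|x) || p(y))]), where p(y|x) is a classifier's softmax output for a generated image and p(y) is the average of those outputs. Below is a minimal sketch of that formula only. The function name `inception_score` and the `eps` smoothing term are my own choices, and the probabilities are passed in directly; in the real metric they come from an Inception network, which is not included here.

```python
import numpy as np


def inception_score(probs, eps=1e-12):
    """Compute exp(mean KL(p(y|x) || p(y))) from an (N, C) array of class probabilities.

    probs: each row is assumed to be a softmax output p(y|x) for one image.
    eps:   small constant (an assumption of this sketch) to avoid log(0).
    """
    probs = np.asarray(probs, dtype=np.float64)
    marginal = probs.mean(axis=0, keepdims=True)  # p(y), averaged over images
    kl = np.sum(probs * (np.log(probs + eps) - np.log(marginal + eps)), axis=1)
    return float(np.exp(kl.mean()))


# Identical, uncertain predictions for every image give the minimum score of 1.
uniform = np.full((8, 4), 0.25)
# Confident predictions spread evenly over 4 classes approach the maximum score of 4.
confident = np.eye(4)
print(inception_score(uniform), inception_score(confident))
```

The score only looks at class probabilities of the images, so nothing in it depends on the text description, which is why the human evaluation is listed alongside it.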
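Downloading the Dataset
I came across the FashionGEN dataset in the FashionBERT paper and wanted to take a closer look, but after filling out the form on the official site (https://fashion-gen.com/) I never heard back. So I searched for related material and found this repository: https://github.com/menardai/FashionGenAttnGAN
It provides 3 files (note: there is no test set; the paper says the test set will not be released and is instead built into the paper's Docker image):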
- fashiongen_256_256_train.h5
- fashiongen_256_256_validation.h5
- fashiongen_consume_data_example.pdf
Analysis Code
Following https://docs.h5py.org/en/stable/quick.html, I used the code below to inspect the files.
import h5py
import numpy as np

BATCH_SIZE = 32


def get_batch(file_h5, features, batch_number, batch_size=32):
    """Get a batch of the dataset

    Args:
        file_h5(h5py.File): the opened dataset file
        features(list(str)): list of names of features present in the dataset
            that should be returned.
        batch_number(int): the id of the batch to be returned.
        batch_size(int): the mini-batch size
    Returns:
        A list of numpy arrays of the requested features"""
    list_of_arrays = []
    lb, ub = batch_number * batch_size, (batch_number + 1) * batch_size
    for feature in features:
        list_of_arrays.append(file_h5[feature][lb: ub])
    return list_of_arrays


# open the file
# file_h5 = h5py.File('fashiongen_256_256_train.h5', mode='r')
file_h5 = h5py.File('fashiongen_256_256_validation.h5', mode='r')
# define the features to be retrieved
list_of_features = ['input_image', 'input_description']
dataset_len = len(file_h5['input_image'])
# number of full batches (a trailing partial batch is ignored)
nb_batches = dataset_len // BATCH_SIZE
batch_nb = np.random.randint(0, nb_batches)
# get a randomly chosen batch of the data
list_of_arrays = get_batch(file_h5, list_of_features, batch_nb, BATCH_SIZE)
# close the file
file_h5.close()
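One detail of the script above: `nb_batches` is rounded down, so any trailing partial batch is never sampled. With 32528 validation samples and `BATCH_SIZE = 32`, that leaves the last 16 samples out. If you want to walk over every sample instead of drawing one random batch, a sketch along these lines would work. The helper name `batch_bounds` is hypothetical and not part of the original example.

```python
def batch_bounds(dataset_len, batch_size):
    """Yield (lb, ub) slice bounds covering every sample, including a final partial batch."""
    for lb in range(0, dataset_len, batch_size):
        yield lb, min(lb + batch_size, dataset_len)


bounds = list(batch_bounds(32528, 32))
print(len(bounds))   # 1017 batches: 1016 full ones plus one partial
print(bounds[-1])    # (32512, 32528), i.e. the last 16 samples
```

Each `(lb, ub)` pair can be used directly as `file_h5[feature][lb:ub]`, the same slicing that `get_batch` does internally.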
This gives 260480 samples in the training set and 32528 in the validation set.
Each file is a dict-like structure with the following keys:
['index', 'index_2', 'input_brand', 'input_category', 'input_composition', 'input_concat_description', 'input_department', 'input_description', 'input_gender', 'input_image', 'input_msrpUSD', 'input_name', 'input_pose', 'input_productID', 'input_season', 'input_subcategory']
Each image has shape:
(256, 256, 3)
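A quick way to get the same overview for every key at once is to loop over the file and collect each entry's shape and dtype. The sketch below assumes every top-level key is a dataset (not a group), which matches the key list above. The helper name `summarize` is mine, and the stand-in dictionary uses made-up shapes and dtypes so the snippet runs without the real files; pass the opened `h5py.File` instead to inspect the actual data.

```python
import numpy as np


def summarize(store):
    """Return {key: (shape, dtype)} for a dict-like store of array-like values."""
    return {key: (tuple(store[key].shape), str(store[key].dtype)) for key in store}


# Stand-in data (shapes and dtypes are invented for illustration, not read from the real file).
fake = {
    'input_image': np.zeros((4, 256, 256, 3), dtype=np.uint8),
    'index': np.arange(4, dtype=np.int64).reshape(4, 1),
}
for key, (shape, dtype) in summarize(fake).items():
    print(key, shape, dtype)
```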
Content Analysis
Taking the validation set as an example, let's go through the fields one by one.
- index
file_h5['index'].shape
# (32528, 1)
file_h5['index'][:]
# Output:
[[    24]
 [    25]
 [    26]
 ...
 [342153]
 [342154]
 [342155]]
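The printed values include both 24 and 342155, yet there are only 32528 rows, so `index` cannot be one contiguous run of integers. A small sketch like the following can confirm that and count where the jumps are. The helper name `describe_index` is hypothetical, and the sample array contains only the six rows visible in the output above; the elided middle rows are not reproduced.

```python
import numpy as np


def describe_index(index):
    """Summarize an (N, 1) or (N,) integer index column."""
    idx = np.asarray(index).ravel()
    steps = np.diff(idx)  # difference between consecutive entries
    return {
        'count': int(idx.size),
        'min': int(idx.min()),
        'max': int(idx.max()),
        'strictly_increasing': bool(np.all(steps > 0)),
        'gaps': int(np.count_nonzero(steps > 1)),  # places where values are skipped
    }


# Only the rows visible in the printed output; not the full column.
sample = np.array([[24], [25], [26], [342153], [342154], [342155]])
print(describe_index(sample))
```

On the real file, calling `describe_index(file_h5['index'][:])` would report how many gaps the validation indices actually have.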
- index_2
file_h5['index_2'