自学pytorch生成对抗网络编程一书源码标注
- 生成文件的方法
hdf5_file = './celeba_aligned_small.h5py'
total_images = 20000
with h5py.File(hdf5_file, 'w') as hf: # 打开h5py文件,文件不存在则会创建文件
count = 0
with zipfile.ZipFile('celeba/img_align_celeba.zip', 'r') as zf:
# 这个压缩文件里是一个文件夹img_align_celeba文件夹中有200000多张图片
for i in zf.namelist(): # zf.namelist()返还压缩文件中的文件列表名
# zf.namelist()[0]是'img_align_celeba/' 即文件夹路径
# zf.namelist()[1]是'img_align_celeba/000001.jpg' 即文件夹下的文件路径
if i[-4:] == '.jpg':
ofile = zf.extract(i) # 解压单个文件至ofile中
# 默认解压在当前文件夹即在'./'路径下创建img_align_celeba文件夹,把图片(i)放入文件夹中
# ofile是解压后图片(i)的相对地址是一个字符串
img = imageio.imread(ofile)
# 使用imageio.imread读取图片,此时img打印出来是一个数组
os.remove(ofile) # 用完即弃
# 删除图片,不占存储空间
hf.create_dataset('img_align_celeba/'+str(count)+'.jpg',
data=img, compression='gzip', compression_opts=9)
# compression是压缩方式, compression_opts是压缩程度的参数
# 在celeba_aligned_small.h5py文件中生成组img_align_celeba,在组中保存img数组
count = count + 1
if count % 1000 == 0:
print('images done ...', count)
if count == total_images: # 只取前20000张图片
break
- 使用文件
with h5py.File('./celeba_aligned_small.h5py', 'r') as file_object:
# h5py文件只读打开创建文件对象file_object,此对象为可遍历对象
for group in file_object: # 对文件对象进行遍历得到组名称的字符串
print(group) # 'img_align_celeba'
# 从群组中导出数据集,并以索引的形式展示
with h5py.File('./celeba_aligned_small.h5py', 'r') as file_object:
dataset = file_object['img_align_celeba'] # dataset即为h5py文件中的一个组
# 对文件对象用组名称进行索引得到组对象
image = np.array(dataset['8.jpg']) # 在组中以文件名索引具体文件
# 组对象用组中的文件名进行索引得到其中的数组,使用numpy转化为numpy数组
print(image.shape)
plt.imshow(image, interpolation='none')
plt.show()
# 这就是从h5py中取出numpy图片张量的方法
- 利用h5py生成自定义Dataset类
from torch.utils.data import Dataset
class CelebADataset(Dataset): # 继承pytorch的Dataset类
def __init__(self, file):
self.file_object = h5py.File(file, 'r') # 文件对象
self.dataset = self.file_object['img_align_celeba'] # 组对象
def __len__(self):
return len(self.dataset)
def __getitem__(self, item): # 从组对象中提取图片输出归一化的torch张量,图片格式为(h,w,3)
if item >= len(self.dataset): # 索引值大于等于长度报错
raise IndexError()
img = np.array(self.dataset[str(item)+'.jpg'])
return torch.cuda.FloatTensor(img) / 255.
# 至此可以实现单张图片处理
dataset = CelebADataset(hdf5_file) # 实例化
print(dataset[0]) # 使用索引直接调用__getitem__函数获得torch张量(h, w, 3)
# 进行批处理
train_iter = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=32)
for i in train_iter:
# 小bug如果声明torch.set_default_tensor_type()
# 这里会报错
print(i.shape) # torch.Size([32, 218, 178, 3])
break
- 使用h5py的优势
图片不会全部加载到内存当中,节约内存开支,在训练时对数据即取即用,相较于普通的在文件夹中存储图片数据的方式,将图片文件压缩进入h5py文件中,即取即用的效率更高