H5文件读取

17 篇文章 2 订阅
11 篇文章 1 订阅

H5文件读取:

import torch.utils.data as data
import torch
import h5py

class DatasetFromHdf5(data.Dataset):
    def __init__(self, file_path):
        super(DatasetFromHdf5, self).__init__()
        hf = h5py.File(file_path)
        self.data = hf.get('data')
        self.target = hf.get('label')


    def __getitem__(self, index):
        return torch.from_numpy(self.data[index,:,:,:]).float(), torch.from_numpy(self.target[index,:,:,:]).float()


    def __len__(self):
        return self.data.shape[0]

调用的时候,先用DataLoader将数据装入 training_data_loader中

 train_set = DatasetFromHdf5(r"D:\PycharmProjects\pytorch-vdsr-master\data\train.h5")
 training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

在使用数据训练的时候写一个循环,iteration只是一个计数的,从1开始计数,表示已经取第iteration个批次了,batch就是每次取出一个批次的数值。

input和target是取出的输入和希望得到的输出,这里的返回顺序是在上边的DatasetFromHdf5中定义的。

 def __getitem__(self, index):
        return torch.from_numpy(self.data[index,:,:,:]).float(), torch.from_numpy(self.target[index,:,:,:]).float()

所以batch[0]表示input(也就是存储的data),batch[1]表示label(也就是label)。
index在这里应该是每次按第一个维度取出data中的数值。data[index,:,:,:],本来是维度是1000×1×41×41,每次取的是1×1×41×41。按照batch来,每次取出的就是batch×1×41×41

  for iteration, batch in enumerate(training_data_loader, 1):
        input, target = Variable(batch[0]), Variable(batch[1], requires_grad=False)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值