一般数据读取会涉及两个函数:
1.get_dataset(data_path,cfg)。其中的__getitem__(self,idx)函数,这个函数用于读取给定路径下的数据并处理成所需的格式,返回一个数据的data,data一般为字典格式。作用:规范数据,path-->tensor(data)。
2.dataloader(get_dataset),可直接调用pytorch中的接口。通过
for i ,data_batch in enumerate(dataloader)
调用前文提到的__getitem__,返回batch_size个data。之后按照pytirch中dataload函数,将多个data进一步处理,返回data_batch。
##先初始化model的网络结构,再将model定义为数据并行的状态
model = DataParallel(model.cuda(),device_ids=[0,1,2])
##训练数据前向传播
model(data_batch)
之后在使用model(data_batch)时,数据便会加载至model所在的gpu上。
pytorch中的DataLoader
def collate(data_batch, samples_per_gpu=1, num_gpus=1):