Nvidia官方数据预读取程序,给神经网络训练提提速

本文介绍了Nvidia官方提供的数据预读取程序,该程序旨在提升神经网络训练速度。通过训练模式下的预加载方法,可以有效提高数据处理效率,加快深度学习模型的训练过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

实现方式

这里给出训练模式下的预加载方法,测试模式基本相同

def train_net(epoch, model, data_trainer, criterion, optimizer):
    model.train()
    prefetcher = data_prefetcher(data_trainer, test=False) #实例化data_prefetcher类
    data, label = prefetcher.next()
    batch_idx = 0
    while data is not None:
        batch_idx += 1
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output.squeeze(1), label)
        train_loss += loss.item()
        loss.backward()
        #torch.nn.utils.clip_grad_value_(model.parameters(), 15)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=20)
        optimizer.step()
        # scheduler.step()
        running_loss += loss.item()
        if  batch_idx  % 50 == 0:
            		打印损失
            running_loss = 0.0
        data, label = prefetcher.next() #加载下一组数据

数据预加载类

class data_prefetcher():
    def __init__(self, loader, test=False):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.test_flag = test
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_data, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            if self.test_flag == True: #这里可以不要test_flag,因为我的测试数据没有name,训练数据有name,所以要区分是在测试还是在训练
                self.next_input = self.next_data[:,:-1].cuda(non_blocking=True)
                self.next_target = self.next_target.cuda(non_blocking=True)
                self.next_name = self.next_data[:,-1].cuda(non_blocking=True)
            else:
                self.next_input = self.next_data.cuda(non_blocking=True)
                self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.reshape(-1,1,3738).float()
            #self.next_input = self.next_input.reshape(-1, 42, 89).float()
            self.next_target = self.next_target.float()#long用于分类。float用于预测
            if self.test_flag:
                self.next_name = self.next_name.to('cpu').numpy().reshape(1,-1)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        if self.test_flag:
            name = self.next_name
        self.preload()
        if self.test_flag == True:
            return input, target, name
        else:
            return input, target
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值