实现方式
这里给出训练模式下的预加载用法；测试模式基本相同，只是构造 data_prefetcher 时传入 test=True，且 next() 会额外返回样本名 name。
def train_net(epoch, model, data_trainer, criterion, optimizer,
              log_interval=50, max_grad_norm=20):
    """Train *model* for one pass over *data_trainer* using a data_prefetcher.

    Args:
        epoch: Epoch index; only used in the progress messages.
        model: Module to train. Its output is squeezed on dim 1 before being
            passed to *criterion*.
        data_trainer: Iterable of ``(data, target)`` batches (e.g. a DataLoader).
        criterion: Loss function called as ``criterion(output, target)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        log_interval: Print the mean loss of the last this-many batches every
            this-many batches.
        max_grad_norm: Gradients are rescaled so their total norm does not
            exceed this value before each optimizer step.

    Returns:
        Mean training loss over all batches of the pass (0.0 if the loader
        yielded no batches).
    """
    model.train()
    prefetcher = data_prefetcher(data_trainer, test=False)
    data, label = prefetcher.next()
    batch_idx = 0
    # Both accumulators must exist before the first `+=`; previously they were
    # never initialised and the first batch raised UnboundLocalError.
    train_loss = 0.0
    running_loss = 0.0
    while data is not None:
        batch_idx += 1
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output.squeeze(1), label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()
        # `.item()` synchronises with the GPU, so read the scalar only once.
        loss_value = loss.item()
        train_loss += loss_value
        running_loss += loss_value
        if batch_idx % log_interval == 0:
            print(f"epoch {epoch} batch {batch_idx}: "
                  f"loss {running_loss / log_interval:.6f}")
            running_loss = 0.0
        # Takes the batch whose copy was started during this iteration and
        # kicks off the copy of the one after it.
        data, label = prefetcher.next()
    return train_loss / batch_idx if batch_idx else 0.0
数据预加载类
class data_prefetcher():
    """Overlap copying the next batch to the GPU with work on the current one.

    Wraps a batch iterable (e.g. a ``DataLoader``). While the caller works on
    the batch returned by ``next()``, the following batch is already being
    copied to the GPU on a dedicated CUDA stream.

    Each item yielded by *loader* must be a ``(data, target)`` pair of
    tensors. The features are reshaped to ``(-1, *input_shape)`` and cast to
    float32; the target is cast to float32 (suitable for regression losses; a
    classification loss would typically need ``long`` instead).

    When ``test=True`` the last column of ``data`` is treated as a per-sample
    identifier: it is split off the features and returned as a third value,
    ``name``, a NumPy array of shape ``(1, batch)``.

    If CUDA is not available, batches are passed through on the CPU without a
    side stream, so the class still works for CPU-only debugging.

    Note: host-to-device copies are only truly asynchronous when the loader
    yields pinned memory (e.g. ``DataLoader(..., pin_memory=True)``).

    Args:
        loader: Iterable yielding ``(data, target)`` tensor pairs.
        test: If True, split the last column of ``data`` off as ``name`` and
            return it from ``next()``.
        input_shape: Per-sample shape the features are reshaped to. The
            default ``(1, 3738)`` reproduces the original layout; e.g.
            ``(42, 89)`` yields a 2-D layout instead.
    """

    def __init__(self, loader, test=False, input_shape=(1, 3738)):
        self.loader = iter(loader)
        # A separate stream lets the copies run concurrently with kernels
        # queued on the current (compute) stream.
        self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.test_flag = test
        self.input_shape = tuple(input_shape)
        self.next_name = None
        self.preload()

    def _to_device(self, tensor):
        """Start a non-blocking copy of *tensor* to the GPU (no-op on CPU)."""
        if self.stream is None:
            return tensor
        return tensor.cuda(non_blocking=True)

    def _stage(self):
        """Split, move and convert ``self.next_data`` / ``self.next_target``."""
        if self.test_flag:
            # NOTE(review): an earlier comment claimed training data carries
            # names and test data does not, but the code has always split the
            # last column only when test=True -- confirm against the dataset.
            # The identifiers are read from the host batch: sending them to the
            # GPU and back forced a blocking device-to-host copy on every batch.
            self.next_name = (
                self.next_data[:, -1].detach().cpu().clone().numpy().reshape(1, -1)
            )
            features = self.next_data[:, :-1]
        else:
            features = self.next_data
        self.next_input = (
            self._to_device(features).reshape(-1, *self.input_shape).float()
        )
        self.next_target = self._to_device(self.next_target).float()

    def preload(self):
        """Fetch the next batch from the loader and start staging it."""
        try:
            self.next_data, self.next_target = next(self.loader)
        except StopIteration:
            # Exhausted: clear every slot so next() signals the end with None
            # instead of handing back a stale name from the previous batch.
            self.next_input = None
            self.next_target = None
            self.next_name = None
            return
        if self.stream is None:
            self._stage()
        else:
            with torch.cuda.stream(self.stream):
                self._stage()

    def next(self):
        """Return the prefetched batch and start prefetching the one after it.

        Returns:
            ``(input, target)``, or ``(input, target, name)`` when
            ``test=True``. Every element is ``None`` once the loader is
            exhausted.
        """
        name = self.next_name if self.test_flag else None
        if self.stream is not None:
            current = torch.cuda.current_stream()
            # Kernels queued on the compute stream must not start before the
            # copies issued in preload() have finished.
            current.wait_stream(self.stream)
            # These tensors were allocated on the side stream. Recording their
            # use on the compute stream stops the caching allocator from
            # recycling their memory while compute kernels may still read it.
            for tensor in (self.next_input, self.next_target):
                if tensor is not None:
                    tensor.record_stream(current)
        batch_input = self.next_input
        batch_target = self.next_target
        self.preload()
        if self.test_flag:
            return batch_input, batch_target, name
        return batch_input, batch_target