Even with model.eval() already in place, the problem persisted. The basic code is shown below. Because the evaluation metrics over the whole test set need to be computed, the per-batch probs must be accumulated in a list; holding onto these tensors keeps their computation graphs alive, so GPU memory keeps growing. Detaching probs and the other tensors (and moving them to the CPU) fixes this, and memory no longer increases during testing.
model.eval()  # switch dropout/batch-norm layers to inference behavior
probs, labels, losses = [], [], []
for batch in self.batch_generator():
    x, label = batch
    scores = model(x)
    loss = criterion(scores, torch.LongTensor(label).cuda())
    # detach() drops the autograd graph reference and cpu() moves the
    # tensor off the GPU, so accumulated batches no longer pin GPU memory
    probs.append(scores.cpu().detach())
    labels.append(torch.LongTensor(label).cpu().detach())
    losses.append(loss.item())  # .item() extracts a plain Python float
probs = torch.cat(probs, 0)
labels = torch.cat(labels, 0)
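With probs and labels concatenated, the overall test metrics can then be computed. A minimal sketch, assuming a multi-class classification task where scores are raw logits; the small fixed tensors and the names avg_loss / accuracy are illustrative stand-ins, not from the original code:

```python
import torch

# Stand-ins for the concatenated results of the loop above.
probs = torch.tensor([[2.0, 0.1], [0.2, 1.5], [0.9, 0.3]])  # logits, shape (N, C)
labels = torch.tensor([0, 1, 1])                             # true classes, shape (N,)
losses = [0.70, 0.65]                                        # per-batch loss floats

avg_loss = sum(losses) / len(losses)                 # mean loss over batches
preds = probs.argmax(dim=1)                          # predicted class per sample
accuracy = (preds == labels).float().mean().item()   # fraction correct

print(avg_loss, accuracy)
```

Because everything here is already detached and on the CPU, these computations hold no autograd graph and cost no GPU memory.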