更新中:
case1:
原代码(目标检测 train),报错的地方在 multilabel_loss 这一行:
def __call__(self):
    """Original (failing) training loop; the notes report an error at the multilabel_loss call.

    Runs 1000 epochs over ``self.train_dataloader``. For each batch it puts the
    network in train mode, moves the four tensors to ``DEVICE``, runs
    ``self.net(img)`` to obtain three outputs, and calls
    ``self.multilabel_loss`` on the unfiltered multilabel predictions/targets.

    NOTE(review): the corrected loop in these notes filters both tensors with
    ``multilabel >= 0`` before the loss, so ``multilabel`` evidently contains
    negative values; presumably the loss cannot accept them as targets, which
    would explain the error here -- confirm against the dataset and loss type.
    """
    for epoch in range(1000):
        for i, (img, label, position, multilabel) in enumerate(self.train_dataloader):
            # print(img.shape, label.shape, position.shape, multilabel.shape) # shapes depend on the batch size; try a few batches to see them
            self.net.train()
            img,label,position,multilabel=img.to(DEVICE),label.to(DEVICE),position.to(DEVICE),multilabel.to(DEVICE) # move the data to DEVICE (the GPU, per the original note)
            out_label,out_position,out_multilabel=self.net(img)
            # This is the call reported to fail: multilabel still contains its
            # negative entries here (presumably the cause -- TODO confirm).
            multilabel_loss=self.multilabel_loss(out_multilabel,multilabel)
            print()
改为下面的代码后解决报错(在计算损失之前,先用 multilabel >= 0 的条件同时筛选预测值 out_multilabel 和标签 multilabel,去掉标签为负的样本):
def __call__(self):
    """Training loop that excludes negative multilabel entries from the multilabel loss.

    Runs 1000 epochs over ``self.train_dataloader``. For each batch it puts the
    network in train mode, moves the four tensors to ``DEVICE``, runs
    ``self.net(img)`` and computes ``self.multilabel_loss`` only on the
    positions where ``multilabel >= 0``.

    NOTE(review): if every ``multilabel`` entry in a batch is negative, both
    filtered tensors are empty; a mean-reduced loss would then typically
    return NaN -- consider guarding that case once the loss type is known.
    """
    for epoch in range(1000):
        for i, (img, label, position, multilabel) in enumerate(self.train_dataloader):
            # Shapes depend on the batch size; uncomment to inspect a few batches:
            # print(img.shape, label.shape, position.shape, multilabel.shape)
            self.net.train()
            # Move the batch to DEVICE (the GPU, per the original note).
            img = img.to(DEVICE)
            label = label.to(DEVICE)
            position = position.to(DEVICE)
            multilabel = multilabel.to(DEVICE)

            out_label, out_position, out_multilabel = self.net(img)

            # Compute the selection once and apply it to both tensors, so
            # predictions and targets are always filtered identically and the
            # result no longer depends on filtering out_multilabel before
            # multilabel is reassigned.
            valid = torch.where(multilabel >= 0)
            out_multilabel = out_multilabel[valid]
            multilabel = multilabel[valid]
            multilabel_loss = self.multilabel_loss(out_multilabel, multilabel)
            print()