解决方法:数据集的labels改成txt格式就好了,很长一段时间没跑代码给忘记了,跟之前的数据集对比了一下区别才发现是这个问题……
2024.5.20更新:
有朋友跟我这个情况确实是不一样的,在这里再提供个思路吧。
关于torch.cat(),官网给出的解释是torch.cat — PyTorch 2.3 documentation.也就是说这个函数在给定维度中会连接给定的张量序列。 但是除了拼接维数dim以外,其他维数都必须一致,才可以进行对齐从而进行拼接。
因此这个问题的出现一定是数据集有问题,那么
1.数据集较少的小伙伴可以检查下自己的数据集格式、具体内容或者设置方面是不是有问题;
2.数据集较多的小伙伴可以先尝试用try跳过这个错误,从而排查一下代码有没有其他问题,有时候可能解决掉之后这个就跟着解决了。
以下是我在另一位作者的评论区里找到的,大家可以参考借鉴一下:
try:
stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()} # to numpy
except:
print("Error: 空的tensor列表无法进行拼接")
stats = {k: np.array([]) for k in self.stats.keys()}
for k, v in self.stats.items():