Warning: masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead
查看代码中加载训练数据集的地方,在生成mask标签的函数中将return中mask换为mask.bool()即可:
heatmaps, mask = generate_label_map(Hpoint, height//self.downsample, width//self.downsample, self.sigma, self