1.问题
使用pytorch中torch.nn.functional模块时,出现以下报错:
问题代码位于:
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
inputs = F.normalize(inputs, mean, std)
targets = F.normalize(targets, mean, std)
想要利用F.normalize进行归一化操作,传入参数为预定义好的均值mean和标准差std。但是F.normalize函数想要传入的参数并不是这两个,因此就会出现上述错误。即F.normalize希望传入上面截图中的6种参数,但是实际接收到的是(Tensor, list, list, keepdim=bool)类型,因此报错。博主用的torch版本是torch 1.8。
2.解决办法
将F.normalize函数改为torchvision.transform.Normalize即可。
import torchvision.transform as transform
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
trans = transform.Normalize(mean, std) # 构建一个归一化操作并命名为trans
inputs = trans(inputs) # 将input传入trans即可
以上便可实现与原F.normalize一样功能。