import torch
import torch.nn as nn
# Demonstrates which tensor dimension nn.Linear and nn.BatchNorm1d treat as
# the feature/"channel" dimension, using a (32, 49, 768) input tensor.
# The tensor is named `x` rather than `input` so the built-in input() is not
# shadowed at module level.
x = torch.randn([32, 49, 768])

# nn.Linear acts on the LAST dimension: in_features must equal x.shape[-1]
# (768); all leading dimensions are carried through unchanged.
linear = nn.Linear(768, 512)
out = linear(x)
print(out.shape)
# torch.Size([32, 49, 512])
# With nn.Linear(49, 512) instead, the call fails with:
#   mat1 and mat2 shapes cannot be multiplied (1568x768 and 49x512)
# (the leading dims are flattened to 32 * 49 = 1568 rows of 768 features).
# Conclusion: for nn.Linear the channel/feature dim must be the last one.

# nn.BatchNorm1d takes (N, C) or (N, C, L) input and normalises over the
# channel dimension at index 1 -- here the size-49 dim, not the 512 features.
bn = nn.BatchNorm1d(49)
out = bn(out)
print(out.shape)
# torch.Size([32, 49, 512])
# With nn.BatchNorm1d(512) instead, the call fails with:
#   RuntimeError: running_mean should contain 49 elements not 512
# Conclusion: for nn.BatchNorm1d the channel dim must be the middle one (dim 1).
# To normalise the 512 features instead, move them to dim 1 and back, e.g.
#   out = nn.BatchNorm1d(512)(out.transpose(1, 2)).transpose(1, 2)
# nn.Linear和nn.BatchNorm1d的维度问题 (dimension conventions of nn.Linear and nn.BatchNorm1d)
# 于 2021-09-10 17:10:57 首次发布 (first published 2021-09-10 17:10:57)