import torch
from torch import nn
m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)
print
pytorch中Linear类中weight的形状问题源码探讨
最新推荐文章于 2023-03-31 15:05:14 发布
import torch
from torch import nn
m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)
print