import torch
from torch import nn
m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)
print
pytorch中Linear类中weight的形状问题源码探讨
最新推荐文章于 2024-08-21 11:23:34 发布
import torch
from torch import nn
m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)
print