torch.nn用法理解和简单案例
-
用nn创建一个线性网络:
import torch import torch.nn as nn m = nn.Linear(20, 30) # 定义一个输入20维,输出30维的线性网络,y=Ax+b,即A是20*30的矩阵 input = torch.randn(128, 20) # 输入数据的大小是128*20且服从正态分布 output = m(input) print(output.size()) # torch.Size([128, 30])
-
用nn进行自适应平均池化,括号中的参数只需要目标的尺寸即可,该方法可以自动匹配输入的尺寸生成池化的大小及步长
# target output size of 5x7 m = nn.AdaptiveAvgPool2d((5,7)) input = torch.randn(1, 64, 8, 9) output = m(input) output.shape # torch.Size([1, 64, 5, 7]) # target output size of 7x7 (square) m = nn