import torch
class Net(torch.nn.Module):
def __init__(self, N, M):
super(Net, self).__init__()
self.W = torch.nn.Parameter(torch.randn(N, M))
self.b = torch.nn.Parameter(torch.randn(M))
def forward(self, input):
return torch.addmm(self.b, input, self.W)
import torch
class Net(torch.nn.Module):
def __init__(self, N, M):
super(Net, self).__init__()
self.W = torch.nn.Parameter(torch.randn(N, M))
self.b = torch.nn.Parameter(torch.randn(M))
def forward(self, input):
return torch.addmm(self.b, input, self.W)
分析以上代码,可以发现,python和c++主要是在寄存变量的时候的区别(寄存变量就是把变量放在CPU里面,可以更快的读取)。python只要把变量变成类的属性(就是self.w的形式),就可以自动把那些变量放在寄存器里。而c++需要手动把这些变量放在寄存器里。
寄存submodules的时候,如下代码,只需要把submudule定义为module的属性(attribute)的时候,就可以自动的把submodul