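# 错误源代码 (incorrect version): forward() below takes a single argument `input`,
# but torch.jit.trace unpacks a tuple of example inputs into separate positional
# arguments, i.e. it calls my_cell(x, h), so tracing fails with a TypeError.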
import torch
class MyCell(torch.nn.Module):
    """Compute a new state ``tanh(linear(x) + h)`` from a single ``(x, h)`` pair.

    This is the *incorrect* variant of the example: ``forward`` takes ONE
    argument and indexes into it. Tracing it with ``torch.jit.trace(cell, (x, h))``
    therefore fails, because the tuple is unpacked into two positional
    arguments. To trace this signature, wrap the pair in an outer tuple:
    ``torch.jit.trace(cell, ((x, h),))``.
    """

    def __init__(self):
        super().__init__()
        # Maps 4 input features to 4 output features; assumes x's last dim is 4.
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, input):
        """Return ``(new_h, new_h)`` where ``new_h = tanh(linear(input[0]) + input[1])``.

        Args:
            input: an indexable pair; ``input[0]`` (x) goes through the linear
                layer and ``input[1]`` (h) is added to the result. h is
                presumably shaped like the linear output — confirm with callers.

        Returns:
            The same new state tensor twice, as a 2-tuple.
        """
        x = input[0]
        h = input[1]
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h
# Incorrect usage demo: trace the single-argument MyCell with a 2-tuple of example inputs.
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
# Example inputs for tracing: a 2-tuple of (3, 4) tensors.
input = (x,h)
# BUG (intentional, this is the "wrong" example): torch.jit.trace treats a
# tuple of example inputs as separate positional arguments, so this calls
# my_cell(x, h). MyCell.forward here takes a single `input`, so tracing is
# expected to fail with a TypeError (too many positional arguments) and the
# lines below never run. To keep the one-argument forward, wrap the pair in
# an outer tuple instead:
#     traced_cell = torch.jit.trace(my_cell, (input,))
#     traced_cell(input)
traced_cell = torch.jit.trace(my_cell, input)
print(traced_cell)
# Even if tracing had succeeded, this passes two arguments to a module whose
# forward takes one, so it would also be rejected.
traced_cell(x, h)
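# 正确源代码 (correct version): forward(x, h) takes two positional arguments,
# which matches how torch.jit.trace passes the (x, h) example tuple; the traced
# module is then called the same way, as traced_cell(x, h).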
import torch
class MyCell(torch.nn.Module):
    """Compute a new state ``tanh(linear(x) + h)`` from two separate tensors.

    This is the *correct* variant of the example: ``forward`` takes ``x`` and
    ``h`` as two positional arguments. That matches
    ``torch.jit.trace(cell, (x, h))``, which unpacks the example tuple into
    ``cell(x, h)``.
    """

    def __init__(self):
        super().__init__()
        # Maps 4 input features to 4 output features; assumes x's last dim is 4.
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        """Return ``(new_h, new_h)`` where ``new_h = tanh(linear(x) + h)``.

        Args:
            x: input tensor fed through the linear layer.
            h: tensor added to the linear output; presumably shaped like that
                output — confirm with callers.

        Returns:
            The same new state tensor twice, as a 2-tuple.
        """
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h
# Build the cell and two random (3, 4) example tensors (x drawn first, then h).
my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
# torch.jit.trace unpacks this tuple into forward's two positional arguments.
input = (x, h)
traced_cell = torch.jit.trace(my_cell, input)
print(traced_cell)
# Invoke the traced module with the same positional arguments used for tracing.
traced_cell(*input)