使用TorchScript,如果直接再forward里调用flattern_parameters()会遇到报错。
解决方法:
因为flattern_parameters()只能在GPU上使用, 所以可以这样写
class Net:
def __init__(device):
...
self.rnn = nn.GRU(656, 1500, 2, batch_first=True).to(device)
...
self.rnn.flattern_parameters()
def forward(self, x):
...