# Error raised by the snippet below (plain tensor attribute read from a script method):
#   attribute 'cx' of type 'Tensor' is not usable in a script method
import torch
class ConstantTensor(torch.jit.ScriptModule):
    """Legacy ``ScriptModule`` that adds a constant all-ones tensor to its input.

    Assigning the tensor as a plain attribute (``self.cx = tensor``) and then
    reading it from a ``@torch.jit.script_method`` fails to compile with
    "attribute 'cx' of type 'Tensor' is not usable in a script method".
    Registering it with ``register_buffer`` instead makes ``self.cx`` usable
    inside the script method.

    Args:
        device: device the constant tensor is allocated on. Defaults to
            ``'cuda'`` (the original behavior); pass ``'cpu'`` on machines
            without a GPU.
    """

    def __init__(self, device='cuda'):
        super().__init__()
        # A registered buffer (not a plain attribute) is what the script
        # method is able to read; the plain attribute triggers the error above.
        self.register_buffer(
            'cx', torch.ones(100, 100, dtype=torch.float, device=device)
        )

    @torch.jit.script_method
    def forward(self, x):
        """Return ``x + self.cx``.

        ``x`` must be broadcast-compatible with the (100, 100) buffer and live
        on the same device as it.
        """
        return x + self.cx
# Build the scripted module (this compiles `forward`) and dump its TorchScript IR.
c = ConstantTensor()
print(str(c.graph))
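# Correct code (see https://github.com/pytorch/pytorch/issues/16284):
# register the tensor as a module buffer instead of assigning it as a plain attribute.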
# These lines belong inside ConstantTensor.__init__, replacing the original assignment.
# Before (fails when read from a script method):
# self.cx = torch.ones(100, 100, dtype=torch.float, device='cuda')
# After: register the tensor as a buffer so the script method can read self.cx.
self.register_buffer('cx', torch.ones(100, 100, dtype=torch.float, device='cuda'))