import torch
class LinearModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
return self.linear(x)
创建模型实例
model = LinearModel()
准备输入数据
x = torch.tensor([[1.0], [2.0]])
运行模型
y = model(x)
将模型转换为Torch Script
scripted_model = torch.jit.script(model)
使用Torch Script进行推理
y_ts = scripted_model(x)
比较两种方式的输出是否相同
print("PyTorch output:", y)
print("Torch Script output:", y_ts)