import torch
from torch import nn
class model(nn.Module):
    """Two stacked ``nn.Conv2d`` layers (3 -> 32 -> 64 channels by default).

    The convolutions use no padding and no activation sits between them.
    With ``stride=1`` each layer shrinks the spatial size by
    ``kernel_size - 1``. For example, a (1, 3, 100, 100) input yields a
    (1, 64, 96, 96) output with the default arguments.

    Args:
        in_channels: Channels of the input tensor (default 3).
        mid_channels: Output channels of the first convolution (default 32).
        out_channels: Output channels of the second convolution (default 64).
        kernel_size: Kernel size shared by both convolutions (default 3).
        stride: Stride shared by both convolutions (default 1).
    """

    def __init__(self, in_channels=3, mid_channels=32, out_channels=64,
                 kernel_size=3, stride=1):
        super().__init__()
        # Keep the attribute name ``layer`` and the Sequential layout so
        # state_dict keys (``layer.0.*``, ``layer.1.*``) stay compatible.
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size, stride),
            nn.Conv2d(mid_channels, out_channels, kernel_size, stride),
        )

    def forward(self, x):
        """Apply both convolutions to ``x`` and return the result."""
        return self.layer(x)
def save(net, input, save_path):
    """Trace ``net`` on an example input and write the TorchScript archive.

    ``net`` is switched to evaluation mode in place before tracing. The
    caller's module is therefore left in eval mode afterwards.

    Args:
        net: The module to trace.
        input: Example input handed to ``torch.jit.trace``; the trace
            records the operations executed on this input.
        save_path: Destination handed to the traced module's ``save``.
    """
    # nn.Module.eval() returns the module itself, so this is the same object.
    inference_net = net.eval()
    torch.jit.trace(inference_net, input).save(save_path)
def load(model_path, map_location=None):
    """Load a TorchScript archive, such as one written by ``save``.

    Args:
        model_path: File name or file-like object of the serialized archive.
        map_location: Optional device remapping forwarded to
            ``torch.jit.load`` (e.g. ``'cpu'`` to load a GPU-traced model on
            a CPU-only machine). ``None`` keeps ``torch.jit.load``'s default.

    Returns:
        The deserialized TorchScript module.
    """
    return torch.jit.load(model_path, map_location=map_location)
if __name__ == '__main__':
    import os

    # torch.Tensor(*sizes) returns *uninitialized* memory (possibly NaN/inf);
    # use a well-defined random tensor as the example input instead.
    # Named to avoid shadowing the builtin ``input``.
    example_input = torch.randn(1, 3, 100, 100)
    model_path = 'model.pt'

    # Create the TorchScript archive on first run so the load below does not
    # fail with a missing-file error on a fresh checkout.
    if not os.path.exists(model_path):
        save(model(), example_input, model_path)

    net = load(model_path)
    with torch.no_grad():  # inference only; no autograd graph needed
        out = net(example_input)
    print(out.size())
# 【PyTorch】torch.jit
# 最新推荐文章于 2024-06-04 20:27:08 发布 (latest recommended article published 2024-06-04 20:27:08)