# define the model
model = Model()
for k, v in model.named_parameters():
print(k, v.size())
#===================================#
pretrained_state = torch.load('pretrained_model.pth')
for i in pretrained_state:
print(i, pretrained_state[i].size())
PyTorch载入模型,并输出参数
最新推荐文章于 2024-05-19 19:02:56 发布