import torch
import torchvision.models as models
import os
# 指定模型保存路径
model_save_dir = './models'
# 检查模型保存目录是否存在,如果不存在则创建
if not os.path.exists(model_save_dir):
os.makedirs(model_save_dir)
print("文件夹创建成功")
# 指定要下载的预训练模型
model_name = 'resnet101'
pretrained_model = models.__dict__[model_name](pretrained=True)
# 将模型保存到指定目录
model_path = os.path.join(model_save_dir, f'{model_name}_pretrained.pth')
torch.save(pretrained_model.state_dict(), model_path)
# # 加载保存的模型
# model = models.__dict__[model_name](pretrained=False)
# model.load_state_dict(torch.load(model_path))
# # 如果需要,将模型设置为评估模式
# model.eval()
# # 现在模型已经加载并可以使用了
# # ...(在这里添加您的代码来使用模型)