迁移学习:如何将PyTorch官方训练好的模型权重加载进自己的模型

如何找到对应模型的权重文件?

        我们想要利用迁移学习去加载PyTorch官方训练好的权重文件,第一步就是下载该文件。

import torchvision.models.model_name

        我们在pycharm中按住Ctrl键,鼠标左键点击model_name跳转一下,找到对应模型的url链接,把它贴到浏览器里面进行下载即可。

加载官方给的预训练文件

        这里使用一段AlexNet训练代码举例,其中未被注释掉的部分就是迁移学习需要添加的代码。

# import os
# import sys
# import json
# import torch
# import torch.nn as nn
# from torchvision import transforms, datasets, utils
# import matplotlib.pyplot as plt
# import numpy as np
# import torch.optim as optim
# from tqdm import tqdm
# from model import AlexNet
#
# import torchvision.models as models
#
#
#
# def main():
#     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#     print("using {} device.".format(device))
#
#     data_transform = {
#         "train": transforms.Compose([transforms.RandomResizedCrop(224),
#                                      transforms.RandomHorizontalFlip(),
#                                      transforms.ToTensor(),
#                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
#         "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
#                                    transforms.ToTensor(),
#                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
#
#     data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
#     image_path = os.path.join(data_root, "data_set", "data")  # flower data set path
#     assert os.path.exists(image_path), "{} path does not exist.".format(image_path)#检查照片是否存在当前路径下
#     train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
#                                          transform=data_transform["train"])
#     train_num = len(train_dataset)
#
#     # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
#     flower_list = train_dataset.class_to_idx
#     #这一步是将字典键和值颠倒位置,即{0:'daisy'......}
#     cla_dict = dict((val, key) for key, val in flower_list.items())
#     # write dict into json file
#     json_str = json.dumps(cla_dict, indent=4)#indent是进行缩进
#     with open('class_indices.json', 'w') as json_file:
#         json_file.write(json_str)
#
#     batch_size = 64
#     nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
#     print('Using {} dataloader workers every process'.format(nw))
#
#     train_loader = torch.utils.data.DataLoader(train_dataset,
#                                                batch_size=batch_size, shuffle=True,
#                                                num_workers=0)
#
#     validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
#                                             transform=data_transform["val"])
#     val_num = len(validate_dataset)
#     validate_loader = torch.utils.data.DataLoader(validate_dataset,
#                                                   batch_size=4, shuffle=False,
#                                                   num_workers=0)
#
#     print("using {} images for training, {} images for validation.".format(train_num,
#                                                                            val_num))
#     # test_data_iter = iter(validate_loader)
#     # test_image, test_label = test_data_iter.next()
#     #
#     # def imshow(img):
#     #     img = img / 2 + 0.5  # unnormalize
#     #     npimg = img.numpy()
#     #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
#     #     plt.show()
#     #
#     # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
#     # imshow(utils.make_grid(test_image))
#
#     net = AlexNet(num_classes=6, init_weights=True)
#
#     net.to(device)

    # Transfer learning: load the downloaded pretrained weights into the
    # network, then replace the last classifier layer so that the output
    # size matches the current classification task.
    net = AlexNet()
    # Local path of the pretrained weight file.
    model_weights_path = './alexnet_pre.pth'
    # Stop early with a readable message if the weight file is missing.
    assert os.path.exists(model_weights_path),'{} do not exist'.format(model_weights_path)
    # The weights are mapped to the CPU while loading; the model is moved to
    # `device` at the end of this snippet.
    net.load_state_dict(torch.load(model_weights_path, map_location='cpu'))
    # Read the input width of the existing final layer (classifier[6]) so the
    # replacement layer accepts the same features.
    in_channel = net.classifier[6].in_features
    # New final layer with 5 outputs (one per target class).
    net.classifier[6] = nn.Linear(in_channel, 5)
    net.to(device)

#     loss_function = nn.CrossEntropyLoss()
#     # pata = list(net.parameters())
#     optimizer = optim.Adam(net.parameters(), lr=0.0002)
#
#     epochs = 10
#     save_path = './AlexNet.pth'
#     best_acc = 0.0
#     train_steps = len(train_loader)
#     for epoch in range(epochs):
#         # train
#         net.train()
#         running_loss = 0.0
#         train_bar = tqdm(train_loader, file=sys.stdout)
#         for step, data in enumerate(train_bar):
#             images, labels = data
#             optimizer.zero_grad()
#             outputs = net(images.to(device))
#             loss = loss_function(outputs, labels.to(device))
#             loss.backward()
#             optimizer.step()
#
#             # print statistics
#             running_loss += loss.item()
#
#             train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
#                                                                      epochs,
#                                                                      loss)
#
#         # validate
#         net.eval()
#         acc = 0.0  # accumulate accurate number / epoch
#         with torch.no_grad():
#             val_bar = tqdm(validate_loader, file=sys.stdout)
#             for val_data in val_bar:
#                 val_images, val_labels = val_data
#                 outputs = net(val_images.to(device))
#                 predict_y = torch.max(outputs, dim=1)[1]
#                 acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
#
#         val_accurate = acc / val_num
#         print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
#               (epoch + 1, running_loss / train_steps, val_accurate))
#
#         if val_accurate > best_acc:
#             best_acc = val_accurate
#             torch.save(net.state_dict(), save_path)
#
#     print('Finished Training')
#
#
# if __name__ == '__main__':
#     main()
# Transfer learning
    # Instantiate the network with its default arguments.
    # NOTE(review): presumably the default output size is 1000 so that it
    # matches the ImageNet-pretrained weights; confirm against model.py.
    net = AlexNet()
    # The downloaded official weights are assumed to be saved as alexnet_pre.pth.
    model_weights_path = './alexnet_pre.pth'
    # Check that the weight file exists before trying to load it.
    assert os.path.exists(model_weights_path),'{} do not exist'.format(model_weights_path)
    # Load the weight file into the network (tensors are mapped to the CPU).
    net.load_state_dict(torch.load(model_weights_path, map_location='cpu'))
    # Get the number of input features of the last linear layer.
    in_channel = net.classifier[6].in_features
    # Replace the last linear layer; 5 is the number of classes to predict.
    net.classifier[6] = nn.Linear(in_channel, 5)
    net.to(device)

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值