如何找到对应模型的权重文件?
我们想要利用迁移学习加载 PyTorch 官方训练好的权重文件,第一步就是下载该权重文件。
import torchvision.models.model_name
其中 model_name 代指具体的模型模块名(例如 alexnet)。在 PyCharm 中按住 Ctrl 键并用鼠标左键点击 model_name 跳转到源码,找到对应模型的 URL 链接,把它粘贴到浏览器中下载即可。
加载官方给的预训练文件
这里使用一段 AlexNet 训练代码举例
# import os
# import sys
# import json
# import torch
# import torch.nn as nn
# from torchvision import transforms, datasets, utils
# import matplotlib.pyplot as plt
# import numpy as np
# import torch.optim as optim
# from tqdm import tqdm
# from model import AlexNet
#
# import torchvision.models as models
#
#
#
# def main():
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# print("using {} device.".format(device))
#
# data_transform = {
# "train": transforms.Compose([transforms.RandomResizedCrop(224),
# transforms.RandomHorizontalFlip(),
# transforms.ToTensor(),
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
# "val": transforms.Compose([transforms.Resize((224, 224)), # cannot 224, must (224, 224)
# transforms.ToTensor(),
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
#
# data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) # get data root path
# image_path = os.path.join(data_root, "data_set", "data") # flower data set path
# assert os.path.exists(image_path), "{} path does not exist.".format(image_path)#检查照片是否存在当前路径下
# train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
# transform=data_transform["train"])
# train_num = len(train_dataset)
#
# # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
# flower_list = train_dataset.class_to_idx
# #这一步是将字典键和值颠倒位置,即{0:'daisy'......}
# cla_dict = dict((val, key) for key, val in flower_list.items())
# # write dict into json file
# json_str = json.dumps(cla_dict, indent=4)#indent是进行缩进
# with open('class_indices.json', 'w') as json_file:
# json_file.write(json_str)
#
# batch_size = 64
# nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
# print('Using {} dataloader workers every process'.format(nw))
#
# train_loader = torch.utils.data.DataLoader(train_dataset,
# batch_size=batch_size, shuffle=True,
# num_workers=0)
#
# validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
# transform=data_transform["val"])
# val_num = len(validate_dataset)
# validate_loader = torch.utils.data.DataLoader(validate_dataset,
# batch_size=4, shuffle=False,
# num_workers=0)
#
# print("using {} images for training, {} images for validation.".format(train_num,
# val_num))
# # test_data_iter = iter(validate_loader)
# # test_image, test_label = test_data_iter.next()
# #
# # def imshow(img):
# # img = img / 2 + 0.5 # unnormalize
# # npimg = img.numpy()
# # plt.imshow(np.transpose(npimg, (1, 2, 0)))
# # plt.show()
# #
# # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
# # imshow(utils.make_grid(test_image))
#
# net = AlexNet(num_classes=6, init_weights=True)
#
# net.to(device)
# transfer learning
# Build the network with its default head so that its parameter shapes can
# match the pretrained checkpoint (presumably the 1000-class ImageNet head --
# confirm against the AlexNet definition).
net = AlexNet()
model_weights_path = './alexnet_pre.pth'
# Raise explicitly rather than `assert`: asserts are stripped under `python -O`,
# which would silently skip this existence check.
if not os.path.exists(model_weights_path):
    raise FileNotFoundError('{} do not exist'.format(model_weights_path))
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
# map_location='cpu' loads the tensors onto the CPU first; the whole model is
# moved to `device` at the end of this snippet.
net.load_state_dict(torch.load(model_weights_path, map_location='cpu'))
# Replace the last classifier layer with a fresh Linear layer that keeps the
# same input width but outputs 5 classes.
in_channel = net.classifier[6].in_features
net.classifier[6] = nn.Linear(in_channel, 5)
net.to(device)
# loss_function = nn.CrossEntropyLoss()
# # pata = list(net.parameters())
# optimizer = optim.Adam(net.parameters(), lr=0.0002)
#
# epochs = 10
# save_path = './AlexNet.pth'
# best_acc = 0.0
# train_steps = len(train_loader)
# for epoch in range(epochs):
# # train
# net.train()
# running_loss = 0.0
# train_bar = tqdm(train_loader, file=sys.stdout)
# for step, data in enumerate(train_bar):
# images, labels = data
# optimizer.zero_grad()
# outputs = net(images.to(device))
# loss = loss_function(outputs, labels.to(device))
# loss.backward()
# optimizer.step()
#
# # print statistics
# running_loss += loss.item()
#
# train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
# epochs,
# loss)
#
# # validate
# net.eval()
# acc = 0.0 # accumulate accurate number / epoch
# with torch.no_grad():
# val_bar = tqdm(validate_loader, file=sys.stdout)
# for val_data in val_bar:
# val_images, val_labels = val_data
# outputs = net(val_images.to(device))
# predict_y = torch.max(outputs, dim=1)[1]
# acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
#
# val_accurate = acc / val_num
# print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
# (epoch + 1, running_loss / train_steps, val_accurate))
#
# if val_accurate > best_acc:
# best_acc = val_accurate
# torch.save(net.state_dict(), save_path)
#
# print('Finished Training')
#
#
# if __name__ == '__main__':
# main()
# transfer learning
# Instantiate the network with its default head. The pretrained weights were
# trained on ImageNet, so the default output size is presumably 1000 --
# confirm against the AlexNet definition.
net = AlexNet()
# The downloaded official weights are saved locally as alexnet_pre.pth.
model_weights_path = './alexnet_pre.pth'
# Check that the weight file exists. Raise explicitly rather than `assert`:
# asserts are stripped under `python -O`, which would skip this check.
if not os.path.exists(model_weights_path):
    raise FileNotFoundError('{} do not exist'.format(model_weights_path))
# Load the weight file into the network.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
net.load_state_dict(torch.load(model_weights_path, map_location='cpu'))
# Read the input width (in_features) of the last linear layer.
in_channel = net.classifier[6].in_features
# Replace the last linear layer; 5 is the number of classes to classify
# images into.
net.classifier[6] = nn.Linear(in_channel, 5)
net.to(device)