7. PyTorch Beginner Tutorial: Training Two Different Pretrained Transfer Models on the CIFAR-10 Dataset

Now we can move on to testing and experiments. Before that, however, we should move our code into .py files and import them as modules.

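**Four files were created in total:

  1. model.py (contains the core Network class)
  2. fc.py (contains the FC class)
  3. cv_model.py (contains the TransferNetworkImg class)
  4. utils.py (contains all utility functions that do not belong to any class)**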

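We should also use a special Jupyter notebook directive that watches all imported files and reloads them automatically. This comes in handy whenever we modify a file for some reason, such as fixing a bug.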

from mylib.utils import *
from mylib.model import *
from mylib.cv_model import *
from mylib.fc import *
from mylib.chkpoint import *

%load_ext autoreload
%autoreload 2

I. Testing and Experiments

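In the code below, we perform the following steps in order:

  • Create our class dictionary and the head dictionary passed to the transfer learning constructor;
  • Create a DenseNet transfer learning object;
  • Unfreeze it;
  • Fit for 3 epochs;
  • Save a checkpoint;
  • Load it back into another variable;
  • Unfreeze again and train for another 3 epochs;
  • Save a checkpoint again;
  • Reload it into another variable;
  • Freeze, then train for 3 more epochs;
  • Save the model again.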
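
The freeze and unfreeze steps in the list above can be pictured with plain PyTorch. The sketch below assumes that such methods toggle `requires_grad` on the pretrained backbone while the new head stays trainable. It is not the `TransferNetworkImg` implementation, and `backbone`, `head`, `freeze`, `unfreeze`, and `count_trainable` are hypothetical stand-ins.

```python
import torch.nn as nn

# Hypothetical stand-ins: a tiny "pretrained" backbone and a new classifier head.
backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
head = nn.Linear(16, 10)
model = nn.Sequential(backbone, head)


def freeze(module):
    """Stop gradient updates for every parameter of the module."""
    for param in module.parameters():
        param.requires_grad = False


def unfreeze(module):
    """Allow gradient updates for every parameter of the module."""
    for param in module.parameters():
        param.requires_grad = True


def count_trainable(module):
    """Number of scalar parameters that would receive gradient updates."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


unfreeze(backbone)
trainable_unfrozen = count_trainable(model)  # backbone and head both train

freeze(backbone)
trainable_frozen = count_trainable(model)    # only the head trains
```

With the backbone frozen, an optimizer only updates the head, which usually makes each epoch cheaper; unfreezing fine-tunes the whole network.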

First, load the data:

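# Explicit imports, in case the star imports from mylib do not re-export these names
import torch
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
# Map class index -> class name
class_dict = {k: v for k, v in enumerate(classes)}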

head={
       'num_outputs':10,
       'layers':[],
       'class_names':class_dict,
       'non_linearity':'relu',
       'model_type':'classifier',
       'model_name':'FC'
     }
train_dataset = datasets.CIFAR10('Cifar10', train=True,
                              download=True)

test_dataset = datasets.CIFAR10('Cifar10', train=False,
                             download=True)

num_train = len(train_dataset)
indices = list(range(num_train))
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=50,sampler=SubsetRandomSampler(indices),
                                           num_workers=0)
trainloader,validloader,testloader = split_image_data(train_dataset,test_dataset,batch_size=50)

train_transform = transforms.Compose([transforms.Resize(224),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.RandomRotation(10),
                                     ])

train_dataset = datasets.CIFAR10('Cifar10',download=False,transform=train_transform)
transform = transforms.Compose([transforms.ToTensor()])

dataset = datasets.CIFAR10('Cifar10',download=False,transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=50,num_workers=0)
cifar10_mean = [0.4915, 0.4823, 0.4468]
cifar10_std  = [0.2470, 0.2435, 0.2616]
batch_size = 50
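
The hard-coded `cifar10_mean` and `cifar10_std` above are per-channel statistics of the images. As a minimal sketch of how such values can be computed from batches produced by a `ToTensor()` loader, the code below accumulates sums over batches and then applies the same `(x - mean) / std` operation that `Normalize()` performs. It uses random tensors instead of CIFAR-10 so that it runs without a download, and `channel_stats` is a hypothetical helper, not a function from `mylib`.

```python
import torch


def channel_stats(batches):
    """Per-channel mean and (population) std over an iterable of (N, C, H, W) batches."""
    total = None
    total_sq = None
    n_pixels = 0
    for images in batches:
        channels = images.shape[1]
        # Rearrange to (C, N*H*W) and accumulate in float64 for numerical stability
        flat = images.transpose(0, 1).reshape(channels, -1).double()
        if total is None:
            total = torch.zeros(channels, dtype=torch.double)
            total_sq = torch.zeros(channels, dtype=torch.double)
        total += flat.sum(dim=1)
        total_sq += (flat ** 2).sum(dim=1)
        n_pixels += flat.shape[1]
    mean = total / n_pixels
    std = (total_sq / n_pixels - mean ** 2).sqrt()
    return mean.float(), std.float()


torch.manual_seed(0)
fake_images = torch.rand(100, 3, 32, 32)   # stand-in for ToTensor() output in [0, 1]
mean, std = channel_stats(fake_images.split(50))

# The same per-channel operation that transforms.Normalize(mean, std) applies
normalized = (fake_images - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)
```

After normalization each channel has approximately zero mean and unit standard deviation over the data the statistics were computed from.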

'''
ToTensor() converts a PIL image or numpy array (H x W x C, values 0-255)
into a float tensor (C x H x W, values in [0.0, 1.0]).

Normalize() is another transform that normalizes each channel using the
per-channel means and standard deviations passed as separate lists or tuples.
'''
train_transform = transforms.Compose([transforms.Resize((224,224)),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.RandomRotation(10),
                                      transforms.ToTensor(),
                                      transforms.Normalize(cifar10_mean, cifar10_std)
                                     ])

test_transform = transforms.Compose([transforms.Resize((224,224)),
                                     transforms.ToTensor(),
                                     transforms.Normalize(cifar10_mean, cifar10_std)
                                    ])

train_data = datasets.CIFAR10('Cifar10', train=True,
                              download=False, transform=train_transform)
test_data = datasets.CIFAR10('Cifar10', train=False,
                             download=False, transform=test_transform)

trainloader,validloader,testloader = split_image_data(train_data,test_data,batch_size=batch_size)
transfer_densenet = TransferNetworkImg(model_name='DenseNet',
                   optimizer_name = 'Adadelta',               
                   best_accuracy_file ='densenet_best_accuracy_cifar10.pth',
                   chkpoint_file ='densenet_cifar10_chkpoint_file',
                   head = head
                   )

set_transfer_model: self.Model set to DenseNet
setting optim Ada Delta
DenseNet: setting head: inputs: 1024 hidden:[] outputs: 10
Transfer: best accuracy = 0.000
setting optim Ada Delta
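
The log line `DenseNet: setting head: inputs: 1024 hidden:[] outputs: 10` suggests that the `head` dictionary is turned into a fully connected classifier that replaces the final layer of the pretrained model. A rough sketch of that idea follows, assuming the head is a stack of `nn.Linear` layers with a ReLU between hidden layers. `build_head` is a hypothetical helper; the real `FC` class may differ, for example in dropout or output activation.

```python
import torch
import torch.nn as nn


def build_head(num_inputs, layers, num_outputs):
    """Build a fully connected head: num_inputs -> hidden layers -> num_outputs."""
    sizes = [num_inputs] + list(layers) + [num_outputs]
    modules = []
    for i in range(len(sizes) - 1):
        modules.append(nn.Linear(sizes[i], sizes[i + 1]))
        if i < len(sizes) - 2:
            # Non-linearity after every layer except the output layer
            modules.append(nn.ReLU())
    return nn.Sequential(*modules)


# Matches the logged configuration: 1024 inputs, no hidden layers, 10 outputs
head_module = build_head(1024, [], 10)
features = torch.randn(4, 1024)        # a batch of 4 feature vectors
logits = head_module(features)          # one score per class for each sample

# A deeper variant with one hidden layer, e.g. for a 512-feature backbone
deeper_head = build_head(512, [256], 10)
```

With `'layers': []`, as in the `head` dictionary used here, the head reduces to a single linear layer from the backbone features to the 10 classes.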

Next we fit the model. Although this is only three epochs, using a GPU is strongly recommended; otherwise training is quite slow.

transfer_densenet.unfreeze()
transfer_densenet.fit(trainloader,validloader,epochs=3,print_every=200)

updating best accuracy: previous best = 83.780 new best = 86.900

transfer_densenet.save_chkpoint()

get_model_params: best accuracy = 86.900
get_model_params: chkpoint file = densenet_cifar10_chkpoint_file
checkpoint created successfully in densenet_cifar10_chkpoint_file

transfer_densenet2 = load_chkpoint('densenet_cifar10_chkpoint_file')

load_chkpoint: best accuracy = 86.900
set_transfer_model: self.Model set to DenseNet
setting optim Ada Delta
DenseNet: setting head: inputs: 1024 hidden:[] outputs: 10
Transfer: best accuracy = 86.900
setting optim Ada Delta
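
The save and reload round trip shown above can be sketched with `torch.save` and `torch.load`. This is only an illustration under the assumption that a checkpoint stores the model weights together with the best accuracy. `save_checkpoint` and `load_checkpoint` are hypothetical helpers, and unlike `load_chkpoint` in this tutorial, which rebuilds the whole object from the file, the sketch loads the weights into a model that already exists.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(model, best_accuracy, path):
    """Store the weights and the best accuracy in a single file."""
    torch.save({'state_dict': model.state_dict(),
                'best_accuracy': best_accuracy}, path)


def load_checkpoint(path, model):
    """Restore the weights into `model` and return it with the stored accuracy."""
    chkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(chkpoint['state_dict'])
    return model, chkpoint['best_accuracy']


original = nn.Linear(4, 2)              # stand-in for a trained model
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'demo_chkpoint_file')
    save_checkpoint(original, 86.9, path)
    # A freshly initialised model receives the saved weights
    restored, best_accuracy = load_checkpoint(path, nn.Linear(4, 2))
```

Keeping the best accuracy in the checkpoint lets a reloaded model continue comparing new validation results against the previous best, as the `Transfer: best accuracy = 86.900` log line indicates.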

transfer_densenet2.unfreeze()
transfer_densenet2.fit(trainloader,validloader,epochs=3,print_every=200)

updating best accuracy: previous best = 90.460 new best = 91.220

transfer_densenet2.save_chkpoint()

get_model_params: best accuracy = 91.220
get_model_params: chkpoint file = densenet_cifar10_chkpoint_file
checkpoint created successfully in densenet_cifar10_chkpoint_file

This time, after unfreezing and training for another 3 epochs, the best accuracy rose to about 91% (91.220).

transfer_densenet3 = load_chkpoint('densenet_cifar10_chkpoint_file')

load_chkpoint: best accuracy = 91.220
set_transfer_model: self.Model set to DenseNet
setting optim Ada Delta
DenseNet: setting head: inputs: 1024 hidden:[] outputs: 10
Transfer: best accuracy = 91.220
setting optim Ada Delta

transfer_densenet3.freeze()
transfer_densenet3.fit(trainloader,validloader,epochs=3,print_every=200)

updating best accuracy: previous best = 93.000 new best = 93.190

transfer_densenet3.save_chkpoint()

get_model_params: best accuracy = 93.190
get_model_params: chkpoint file = densenet_cifar10_chkpoint_file
checkpoint created successfully in densenet_cifar10_chkpoint_file

After 9 epochs (6 unfrozen, 3 frozen), the best accuracy reached 93.190%.

Repeat the steps above with ResNet34

transfer_resnet = TransferNetworkImg(model_name='ResNet34',
                   optimizer_name = 'Adadelta',               
                   best_accuracy_file ='resnet34_best_accuracy_cifar10.pth',
                   chkpoint_file ='resnet34_cifar10_chkpoint_file',
                   head = head
                   )

set_transfer_model: self.Model set to ResNet34
setting optim Ada Delta
ResNet34: setting head: inputs: 512 hidden:[] outputs: 10
Transfer: best accuracy = 0.000
setting optim Ada Delta

transfer_resnet.unfreeze()
transfer_resnet.fit(trainloader,validloader,epochs=3,print_every=200)

updating best accuracy: previous best = 82.680 new best = 86.730

transfer_resnet.save_chkpoint()

get_model_params: best accuracy = 86.730
get_model_params: chkpoint file = resnet34_cifar10_chkpoint_file
checkpoint created successfully in resnet34_cifar10_chkpoint_file

transfer_resnet2 = load_chkpoint('resnet34_cifar10_chkpoint_file')

load_chkpoint: best accuracy = 86.730
set_transfer_model: self.Model set to ResNet34
setting optim Ada Delta
ResNet34: setting head: inputs: 512 hidden:[] outputs: 10
Transfer: best accuracy = 86.730
setting optim Ada Delta

transfer_resnet2.unfreeze()
transfer_resnet2.fit(trainloader,validloader,epochs=3,print_every=200)

updating best accuracy: previous best = 86.940 new best = 89.830

transfer_resnet2.save_chkpoint()

get_model_params: best accuracy = 89.830
get_model_params: chkpoint file = resnet34_cifar10_chkpoint_file
checkpoint created successfully in resnet34_cifar10_chkpoint_file

transfer_resnet3 = load_chkpoint('resnet34_cifar10_chkpoint_file')

load_chkpoint: best accuracy = 89.830
set_transfer_model: self.Model set to ResNet34
setting optim Ada Delta
ResNet34: setting head: inputs: 512 hidden:[] outputs: 10
Transfer: best accuracy = 89.830
setting optim Ada Delta

transfer_resnet3.freeze()
transfer_resnet3.fit(trainloader,validloader,epochs=3,print_every=200)
transfer_resnet3.save_chkpoint()

get_model_params: best accuracy = 92.780
get_model_params: chkpoint file = resnet34_cifar10_chkpoint_file
checkpoint created successfully in resnet34_cifar10_chkpoint_file
