【PyTorch教程】P25 pytorch中现有模型

P25 pytorch中现有模型

  • 位置:
    在这里插入图片描述

  • 预训练的意思pretrain,是已经在ImageNet数据集上训练好的:progress是对下载的管理:
    在这里插入图片描述

  • 使用的dataset,ImageNet:需要安装scipy库:

在这里插入图片描述

  • 点开这个ImageNet看里面的信息:
    在这里插入图片描述

  • 里面的重要信息:
    在这里插入图片描述

  • 转而使用已经训练好的model:
    在这里插入图片描述

  • 上图:false意思是不下载已经在ImageNet里面训练好的模型,即conv、pooling layers里面的那些参数,而true就要下载他们。
    对比二者的参数:

在这里插入图片描述
在这里插入图片描述

  • 使用vgg16,用在CIFAR数据集上,进行分类:
    Vgg16训练时,用的是ImageNet数据集,它把数据分为1000个类,而CIFAR把数据分为10类,那么就有两种做法,来利用vgg16来处理 CIFAR数据集:1、vgg16后面加一个新的线性层,使1000映射到10;2、直接把vgg16最后的输出层改为10类:这里的add_module是集成 - 在pytorch当中的方法了,直接用:
    下图是第一种方法:
    在这里插入图片描述

  • 还有下面这种写法,可以把新添加的层,放在classifier的框架底子,变成classifier的子集,原来是在大的框架vgg的直属下面:
    在这里插入图片描述

  • 下面是第二个方法:替换原来的输出类型数:图中圈2:
    在这里插入图片描述

可以运行的代码

# -*- coding: utf-8 -*-

import torchvision

train_data = torchvision.datasets.ImageNet("../data_image_net", split='train', download=True,
                                           transform=torchvision.transforms.ToTensor())

from torch import nn

vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)

# print(vgg16_true)

'''
print的结果:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)    # 由于 imagenet 数据集,他的分类结果是 1000,所以这里out_features 值为1000
  )                                                                # 要想用于 CIFAR10 数据集, 可以在网络下面多加一行,转成10分类的输出         
)
'''

train_data = torchvision.datasets.CIFAR10('../dataset', train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)

# vgg16_true.add_module('add_linear',nn.Linear(1000, 10))
# 要想用于 CIFAR10 数据集, 可以在网络下面多加一行,转成10分类的输出,这样输出的结果,跟下面的不一样,位置不一样

vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
# 层级不同
# 如何利用现有的网络,改变结构
print(vgg16_true)

# 上面是添加层,下面是如何修改VGG里面的层内容
print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)  # 中括号里的内容,是网络输出结果自带的索引,套进这种格式,就可以直接修改那一层的内容
print(vgg16_false)


'''
这个教程,可以自己修改别人已经写好了的模型,或者在里面添加自己的需求
'''

完整目录

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值