以vgg16网络为例,介绍导入已有训练好(直接用来预测)和未训练好的网络(用其框架);还介绍了对已导入网格的修改和添加操作,从而做出适用于自己问题的网络。
注:目前的对网络的操作包括修改和添加。
# 可以导入已有训练好(直接用)的或者未训练好的模型(仅用其框架)
# 可以对框架进行修改,添加操作从而便于自己问题的应用
# 以vgg16为例
import torchvision
# 1、加载已有模型
# 下载训练好的数据集
# vgg16_trained = torchvision.models.vgg16(pretrained=True)
# 调用未训练好的数据集,仅用其框架
from torch import nn
vgg16_untrained = torchvision.models.vgg16(pretrained=False)
print(vgg16_untrained)
# 2、修改已有模型
# 2.1 在结尾添加新的层,比如Affine层
vgg16_untrained.add_module('new_Linear', nn.Linear(1000, 10)) # 'new_Linear'为新层的名字
print(vgg16_untrained)
# 2.2 在某一个模块(sequential)中增加
vgg16_untrained.classifier.add_module('new', nn.Linear(1000, 1000)) # classifier是VGG中的一个sequential名
print(vgg16_untrained)
# 2.3 修改某一个模块(sequential)中修改
vgg16_untrained.classifier[7] = nn.Linear(1000, 10) # 修改sequential中的某一个可用[],不管叫什么(例如new),只按照序号7
vgg16_untrained.new_Linear = nn.Linear(10, 10) # 修改单独模块,此时不用索引
print(vgg16_untrained)