- 下载网络模型
import torchvision  # snippet: build VGG16 with and without pretrained weights
build_vgg16 = torchvision.models.vgg16  # short alias for the model constructor
vgg16_true = build_vgg16(pretrained=True)  # pretrained weights requested (fetched on first use)
vgg16_false = build_vgg16(pretrained=False)  # same architecture, no pretrained weights loaded
print("ok")  # marker that both models were constructed
print(vgg16_true)  # show the layer structure of the pretrained model
- 查看函数的方法
import torchvision  # snippet: read the built-in documentation of a model constructor
vgg16_builder = torchvision.models.vgg16  # the callable whose docs we want to inspect
help(vgg16_builder)  # prints its signature and docstring
3. 网络模型添加
```python
import torchvision
from torch import nn

# Download (if needed) the CIFAR10 training split as tensors. The dataset is
# not consumed in this snippet; it only motivates the 10-output layer below.
dataset = torchvision.datasets.CIFAR10(
    "./dataset",
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

# VGG16 with pretrained weights; the layer added below expects 1000 inputs.
vgg16_true = torchvision.models.vgg16(pretrained=True)

# Append the 1000 -> 10 projection to the `classifier` Sequential so that it
# actually runs in the forward pass. Registering it on the VGG root would only
# make it show up in the printout: VGG.forward calls features, avgpool and
# classifier, and never a root-level `add_linear`.
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)
```
- 网络模型修改
```python
import torchvision
from torch import nn

# Build VGG16 without loading pretrained weights.
vgg16_false = torchvision.models.vgg16(pretrained=False)
print(vgg16_false)  # structure before the change

# Swap the classifier entry at index 6 for a 4096 -> 10 linear layer.
ten_class_head = nn.Linear(in_features=4096, out_features=10)
vgg16_false.classifier[6] = ten_class_head
print(vgg16_false)  # structure after the change