PyTorch框架中有一个非常重要且好用的包:torchvision,该包主要由3个子包组成,分别是:
torchvision.datasets
torchvision.models
torchvision.transforms
torchvision.models 使用。以 vgg16为例子
- 导入预训练模型:
import torchvision
model = torchvision.models.vgg16(pretrained=True)
2) 只导入网络结构,不导入参数:
model = torchvision.models.vgg16(pretrained=False) #主要是这里改为False
3) 由于 pretrained 参数默认是 False,所以 2) 等价于:
model = torchvision.models.vgg16()
transforms.ToTensor()
(1) transforms.ToTensor() 将numpy的ndarray或PIL.Image读的图片转换成形状为(C,H, W)的Tensor格式,且/255归一化到[0,1.0]之间
class 网络名字(nn.Module):
def init(self, 一些定义的参数):
super(网络名字, self).init()
self.layer1 = nn.Linear(num_input, num_hidden)
self.layer2 = nn.Sequential(…)
…
定义需要用的网络层
def forward(self, x): # 定义前向传播
x1 = self.layer1(x)
x2 = self.layer2(x)
x = x1 + x2
...
return x