以自己写的vgg模型为例:
import torch
import torch.nn as nn
import numpy as np
import collections
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0, bn=0, relu=1, pooling=0):
    """Return a ``Conv2d`` optionally followed by BatchNorm2d, ReLU and 2x2 max-pooling.

    ``bn``, ``relu`` and ``pooling`` are truthy/falsy switches; the enabled
    extras are appended after the convolution in that fixed order
    (BatchNorm -> ReLU -> MaxPool).  By default only the ReLU is added.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution (and of the BatchNorm).
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution zero-padding.
        bn: add ``nn.BatchNorm2d(out_channels)`` when truthy.
        relu: add an in-place ``nn.ReLU`` when truthy.
        pooling: add ``nn.MaxPool2d(2, 2)`` when truthy.

    Returns:
        nn.Sequential: the convolution followed by the selected extras.
    """
    # Each optional stage is a (switch, factory) pair; factories are only
    # called for enabled stages, preserving the original construction order.
    extras = (
        (bn, lambda: nn.BatchNorm2d(out_channels)),
        (relu, lambda: nn.ReLU(inplace=True)),
        (pooling, lambda: nn.MaxPool2d(2, 2)),
    )
    conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size,
                           stride=stride, padding=padding)
    return nn.Sequential(conv_layer, *(build() for wanted, build in extras if wanted))
def fc(in_channels, out_channels, relu=0):
    """Return a fully connected layer, optionally followed by an in-place ReLU.

    Args:
        in_channels: input features of the ``nn.Linear``.
        out_channels: output features of the ``nn.Linear``.
        relu: append ``nn.ReLU(inplace=True)`` when truthy (off by default).

    Returns:
        nn.Sequential: ``Linear`` alone, or ``Linear`` + ``ReLU``.
    """
    linear = nn.Linear(in_channels, out_channels)
    if not relu:
        return nn.Sequential(linear)
    return nn.Sequential(linear, nn.ReLU(inplace=True))
def conv_dw(in_channels, out_channels, stride=1):
    """Depthwise-separable convolution block.

    A 3x3 depthwise convolution (one filter per input channel, via
    ``groups=in_channels``) followed by a 1x1 pointwise convolution that mixes
    channels; each convolution is followed by BatchNorm2d and an in-place ReLU.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels produced by the 1x1 conv.
        stride: stride of the depthwise 3x3 conv.  The default of 1 keeps the
            spatial size (padding=1); 2 roughly halves it, for downsampling.

    Returns:
        nn.Sequential with six modules:
        [dw Conv2d, BN, ReLU, pw Conv2d, BN, ReLU].
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, in_channels, 3, stride=stride, padding=1, groups=in_channels),
        nn.BatchNorm2d(in_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels, out_channels, 1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )
def load_weight(model, path, map_location='cpu'):
    """Copy every compatible tensor from a checkpoint file into ``model``.

    A checkpoint entry is used only when ``model`` has a key of the same name
    and the tensor sizes match.  Every other model entry keeps its current
    value.  This is useful, e.g., when the classifier head was trained for a
    different number of classes.

    NOTE: ``torch.load`` unpickles the file; only load checkpoints you trust.

    Args:
        model: the ``nn.Module`` to update in place.
        path: checkpoint file readable by ``torch.load``.  It may be a dict
            holding the weights under ``'state_dict'``, or a bare state dict.
        map_location: forwarded to ``torch.load``.  The default ``'cpu'`` lets
            a GPU-saved checkpoint load on a CPU-only machine;
            ``load_state_dict`` then copies the values onto the model's own
            parameters.

    Returns:
        list[str]: model state keys that were NOT loaded, because they were
        absent from the checkpoint or had a different size.
    """
    checkpoint = torch.load(path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        load_state = checkpoint['state_dict']
    else:
        load_state = checkpoint
    model_state = model.state_dict()
    new_state = collections.OrderedDict()
    skipped = []
    for k, v in model_state.items():
        if k in load_state and load_state[k].size() == v.size():
            new_state[k] = load_state[k]
        else:
            skipped.append(k)
    # strict=False: the keys filtered out above must keep their current values.
    # With the default strict=True, load_state_dict would raise "Missing key(s)".
    model.load_state_dict(new_state, strict=False)
    return skipped
class vgg19(nn.Module):
    """VGG-style image classifier built from the ``conv`` and ``fc`` helpers.

    NOTE(review): despite the name, this is not the canonical VGG19.  It has
    12 unpadded 3x3 conv layers in four pooled stages (2, 2, 4, 4), whereas the
    reference model has 16 padded convs in five stages.  Because ``padding=0``,
    a 224x224 input shrinks to a 512x7x7 feature map, which is what the first
    FC layer is sized for.

    Args:
        classes_num: number of output logits.
    """

    def __init__(self, classes_num):
        super().__init__()
        self.classes_num = classes_num
        self.seq = nn.Sequential(
            conv(3, 64, 3),
            conv(64, 64, 3, pooling=1),
            conv(64, 128, 3),
            conv(128, 128, 3, pooling=1),
            conv(128, 256, 3),
            conv(256, 256, 3),
            conv(256, 256, 3),
            conv(256, 256, 3, pooling=1),
            conv(256, 512, 3),
            conv(512, 512, 3),
            conv(512, 512, 3),
            conv(512, 512, 3, pooling=1),
        )
        # The hidden FC layers need a non-linearity.  ``fc`` defaults to
        # relu=0, and stacked Linear layers without activations collapse into
        # one linear map.  ReLU has no parameters, so state_dict keys stay the same.
        self.fc = nn.Sequential(
            fc(512 * 7 * 7, 4096, relu=1),
            fc(4096, 4096, relu=1),
            fc(4096, self.classes_num, relu=0),
        )

    def forward(self, x):
        """Return class logits of shape (batch, classes_num).

        ``x`` is expected to be (batch, 3, 224, 224).  Other spatial sizes do
        not produce the 512x7x7 feature map the classifier was sized for, and
        the first Linear layer will then reject the input.
        """
        x = self.seq(x)
        # Flatten per sample, keeping the batch dim.  The previous
        # reshape(-1, 512*7*7) could silently fold features from different
        # samples together when the input size differed.
        x = x.reshape(x.size(0), -1)
        return self.fc(x)
if __name__ == '__main__':
    # Smoke test: push one random 3x224x224 image through a 1000-class model
    # and print the shape of the resulting logits tensor.
    dummy_batch = torch.randn(1, 3, 224, 224)
    network = vgg19(1000)
    logits = network(dummy_batch)
    print(logits.shape)
一些必要的解释:
- `nn.Sequential()` 可以把若干个操作依次包含在一个容器内，这样在 `forward()` 函数里就不用再逐个把对输入的操作重复写一遍。
- 通常来说，每个 3x3 的 conv2d 后面都会接上 bn 和 relu，因此可以先自定义一个 `conv` 函数，通过参数控制是否加入 bn、relu（以及池化），能够有效减少重复代码。
- Python 中，在实参前加 `*`（一个星号）表示把列表、元组等可迭代对象中的元素依次解包，作为位置参数传入；加 `**`（两个星号）表示把字典解包，以“键=值”的形式作为关键字参数传入（字典的键即形参名）。
- `nn.ReLU(inplace=True)` 表示直接在输入张量上原地修改，而不另外创建一个新的输出张量，因此可以节省内存（显存）。
- 调用 `collections.OrderedDict()` 可以生成一个有序字典，其中的元素按插入顺序排列。在 Python 3.7 之前，普通字典不保证顺序，遍历或输出的顺序不一定；从 Python 3.7 起，普通 `dict` 也保证保持插入顺序。