【自用】PyTorch实现神经网络模型的代码(小技巧

以自己写的vgg模型为例:
import torch
import torch.nn as nn
import numpy as np
import collections

def conv(in_channels, out_channels, kernel_size, stride=1, padding=0, bn=0, relu=1, pooling=0):
    modules = [nn.Conv2d(in_channels,out_channels,kernel_size,stride=stride,padding=padding)]
    if bn:
        modules.append(nn.BatchNorm2d(out_channels))
    if relu:
        modules.append(nn.ReLU(inplace=True))
    if pooling:
        modules.append(nn.MaxPool2d(2,2))
    return nn.Sequential(*modules)

def fc(in_channels, out_channels,relu=0):
    modules = [nn.Linear(in_channels,out_channels)]
    if relu:
        modules.append(nn.ReLU(inplace=True))
    return nn.Sequential(*modules)

def conv_dw(in_channels,out_channels):
    """
    深度可分解卷积
    """
    return nn.Sequential(
        # point-wise
        nn.Conv2d(in_channels,in_channels,3,padding=1,groups=in_channels),
        nn.BatchNorm2d(in_channels),
        nn.ReLU(inplace=True),
        
        # depth-wise
        nn.Conv2d(in_channels,out_channels,1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )

def load_weight(model, path):
    checkpoint = torch.load(path)
    load_state = checkpoint['state_dict'] # 这一步还是需要先去看一下预训练模型中key长什么样
    model_state = model.state_dict()
    new_state = collections.OrderedDict() # 生成一个有顺序的字典
    for k in model_state.keys():
        if k in load_state and load_state[k].size() == model_state[k].size():
            new_state[k] = load_state[k]
    model.load_state_dict(new_state)
        
class vgg19(nn.Module):
    def __init__(self,classes_num):
        super(vgg19,self).__init__()
        self.classes_num = classes_num
        self.seq = nn.Sequential(
            conv(3,64,3),
            conv(64,64,3,pooling=1),
            conv(64,128,3),
            conv(128,128,3,pooling=1),
            conv(128,256,3),
            conv(256,256,3),
            conv(256,256,3),
            conv(256,256,3,pooling=1),
            conv(256,512,3),
            conv(512,512,3),
            conv(512,512,3),
            conv(512,512,3,pooling=1),
        )
        self.fc = nn.Sequential(
            fc(512*7*7,4096),
            fc(4096,4096),
            fc(4096,self.classes_num,relu=0)
        )
    def forward(self, x):
        x = self.seq(x)
        x = x.reshape(-1,512*7*7)
        x = self.fc(x)    
        return x

if __name__ == '__main__':
    image = torch.randn(1,3,224,224)
    model = vgg19(1000)
    out = model(image)
    print(out.shape)
一些必要的解释:
  1. nn.Sequential() 可以把若干个操作包括在一个容器内,在forward()函数里就不用再挨个把对输入的操作复述一遍。
  2. 通常来说,每个3x3的conv2d后面都会加上bn和relu,因此可以先自定义一个conv函数,把bn和relu都加进去,能够有效减少代码重复量。
  3. python中,带×号(星号)的意思表示将参数的所有内容依次作为形参传入,带××(两个星号)的意思是将字典型的所有内容的数值作为形参传入。
  4. nn.ReLU(inplace=True)的意思是直接改变输入,而不另创一个新的输出。可以节省空间和时间。
  5. 调用collections.OrderedDict()函数可以生成一个有顺序的字典。一般的字典是无序的,输出是顺序不一定。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值