Pytorch预训练模型

Pytorch预训练模型

Pytorch支持预训练的模型如下

From torchvision package:

  • ResNet (resnet18resnet34resnet50resnet101resnet152)
  • DenseNet (densenet121densenet169densenet201densenet161)
  • Inception v3 (inception_v3)
  • VGG (vgg11vgg11_bnvgg13vgg13_bnvgg16vgg16_bnvgg19vgg19_bn)
  • SqueezeNet (squeezenet1_0squeezenet1_1)
  • AlexNet (alexnet)

From Pretrained models for PyTorch package:

  • ResNeXt (resnext101_32x4dresnext101_64x4d)
  • NASNet-A Large (nasnetalarge)
  • NASNet-A Mobile (nasnetamobile)
  • Inception-ResNet v2 (inceptionresnetv2)
  • Dual Path Networks (dpn68dpn68bdpn92dpn98dpn131dpn107)
  • Inception v4 (inception_v4)
  • Xception (xception)
  • Squeeze-and-Excitation Networks (senet154se_resnet50se_resnet101se_resnet152se_resnext50_32x4dse_resnext101_32x4d)
  • PNASNet-5-Large (pnasnet5large)
  • PolyNet (polynet)

Pytorch预训练运行环境

  • Python 3.5+
  • PyTorch 0.3+

安装方法

pip install cnn_finetune

用例

from cnn_finetune import make_model
from torch import optim
model = make_model(model_name='resnext101_64x4d', num_classes=10, pretrained=True)
# optimizer
optimizer = optim.SGD(model.parameters(), lr=0.0001)

导入包,将make_model中的num_classes改成你项目中分类类别数。如果你选择的model是VGG系列或者其他需要设置固定输入格式大小的模型还会有input_size需要设置。需要将输入通过cv.resize完成或者特征提取成相应大小。

简单的完整示例

"""
!/usr/bin/env python
-*- coding:utf-8 -*-
Author: eric.lai
Created on 2019/7/22 9:48
"""
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from cnn_finetune import make_model
from torch import optim
import torch.nn as nn
import torch

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
model = make_model(model_name='resnext101_64x4d', num_classes=10, pretrained=True)
# optimizer
optimizer = optim.SGD(model.parameters(), lr=0.0001)

def load_data():
    model.train()
    criterion = nn.MSELoss() # CrossEntropyLoss does not expect a one-hot encoded vector as the target
    total_loss = 0
    total_size = 0

    for i in range(100):
        optimizer.zero_grad()
        batch_xs, batch_ys = mnist.train.next_batch(10)
        batch_xs = list(batch_xs)*3
        batch_xs = np.array(batch_xs)
        batch_xs = np.reshape(batch_xs, [-1, 3, 28, 28])
        batch_xs = torch.from_numpy(batch_xs)
        batch_ys = torch.from_numpy(batch_ys)
        output = model(batch_xs)
        loss = criterion(output,target=batch_ys.float())
        total_loss += loss.item()
        total_size += batch_xs.shape[0]
        loss.backward()
        optimizer.step()
        print(loss)
        test()

def test():
    model.eval()
    with torch.no_grad():
        batch_xs, batch_ys = mnist.test.next_batch(100)
        batch_xs = list(batch_xs) * 3
        batch_xs = np.array(batch_xs)
        batch_xs = np.reshape(batch_xs, [-1, 3, 28, 28])
        batch_xs = torch.from_numpy(batch_xs)
        output = model(batch_xs)
        correct = list(np.equal(np.argmax(output, axis=1),np.argmax(batch_ys, axis=1)))
        nums = correct.count(1)
    print('test accuracy: ', nums / len(batch_xs))
    return output

if __name__ == '__main__':
    load_data()

需要注意的几点:

1.Pytorch是channel_first,在reshape的时候要把通道数放在第一个位置

2.输入的数据要通过torch_from_numpy转换成Tensor

参考来源

https://github.com/Laicheng0830/pytorch-cnn-finetune#from-torchvision-package

  • 1
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值