【PyTorch】 99%程序员都不知道, 深度学习还能这样玩 (建议收藏)

概述

你还在为训练无从下手而苦恼么?
你还在为模型训练时间漫长而痛苦么?
你还在为模型准确率提升困难在深夜一个人啜泣么?

在这里插入图片描述

今天教大家一个方法, 使得我们的模型起跑线上直接甩开别人几条街. 隔壁王叔叔都学会了!

迁移学习

迁移学习 (Transfer Learning) 是把已学训练好的模型参数用作新训练模型的起始参数.

入住 GitHub

经过几天的日夜狂肝, 本人完成了在 GitHub 上的第一个项目. 把迁移学习封装成了一个有手就能用的黑盒模型.

在这里插入图片描述
大家只要替换自己的数据集就可以实现多个可选模型迁移学习并自动保存. 就是两个字简单

项目详解

GitHub 链接
在这里插入图片描述

get_data.py (获取数据)

目前支持 MNIST, Fashion MNIST, CIFAR 10 和 CIFAR 100 数据集.

可以在```get_data.py``下自行替换成自己需要的数据集:
在这里插入图片描述

传入数据的格式为:

data_loader = {"train": train_loader, "valid": test_loader}

get_model (获取模型)

目前支持:

  • resnet18
  • resnet34
  • resnet50
  • resnet101
  • resnet152
  • alexnet
  • squeezenet
  • vgg11
  • vgg13
  • vgg16
  • vgg19

替换模型的方法:

python main.py --model_name "模型名称"

例如, 使用 vgg 13:

python main.py --model_name vgg13

例如, 使用 resnet 152:

python main.py --model_name resnet152

参数详解

必填参数:

  • model_name: 模型名称, 类型为 string
  • num_classes: 输出类别数, 类型为 int (例如 MNIST 是 10 分类, CIFAR 100 是 100 分类)

重要参数:

  • data_name: 数据名称, 类型为 string, 默认为 CIFAR10
  • data_gray: 是否为灰度图, 类型为 boolean, 默认为 False
  • num_epochs: 迭代次数, 类型为 int, 默认为 20
  • batch_size: 一个批次的样本数目, 默认为 512

可选参数 (不建议修改):

  • feature_exact: 是否冻层, 类型为 boolean, 默认为 False
  • use_pretrained: 是否使用预训练权重, 类型为 boolean, 默认为 True
  • pretrained_model_path: 预训练权重, 类型为 string, 默认为 pretrained_model/
  • model_save_path: 模型保存路径, 类型为 string, 默认为 “checkpoint/”
  • visualize: 模型可视化, 类型为 boolean, 默认为 True

使用说明

首先我们需要cd到文件路径, 例如:

cd C:\Users\Windows\Desktop\Project\transfer_learning-main

训练 MNIST

使用 resnet18 训练 MNIST 数据集:

python main.py --data_name MNIST --data_gray True --model_name resnet18 --num_classes 10 --batch_size 512

训练 Fashion MNIST

使用 resnet34 训练 Fashion MNIST 数据集:

python main.py --data_name FashionMNIST --data_gray True --model_name resnet34 --num_classes 10 --batch_size 512

训练 CIFAR 10

使用 resnet50 训练 CIFAR 10 数据集:

python main.py --data_name CIFAR10 --model_name resnet50 --num_classes 10 --batch_size 512

训练 CIFAR 100

使用 resnet152 训练 CIFAR 10 数据集:

python main.py --data_name CIFAR100 --model_name resnet152 --num_classes 100 --batch_size 512

训练自己的数据

python main.py --data_name other --model_name ? --num_classes ? --batch_size ? --epochs ?
评论 38
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值