import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import argparse
from resnet18 import ResNet18
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 参数设置,使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
#然后创建一个解析对象;然后向该对象中添加你要关注的命令行参数和选项,每一个add_argument方法对应一个你要关注的参数或选项;最后调用parse_args()方法进行解析;
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--outf', default='./model18/', help='folder to output images and model checkpoints') #输出结果保存路径
parser.add_argument('--net', default='./model18/Resnet18.pth', help="path to net (to continue training)") #恢复训练时的模型路径
args = parser.parse_args()
# 超参数设置
EPOCH = 135 #遍历数据集次数
pre_epoch = 0 # 定义已经遍历数据集的次数
BATCH_SIZE = 128 #批处理尺寸(batch_size)
LR = 0.1 #学习率
# 准备数据集并预处理
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4), #先四周填充0,在吧图像随机裁剪成32*32
transforms.RandomHorizontalFlip(), #图像一半的概率翻转,一半的概率不翻转
transforms.ToTensor(), #维度转化 由32x32x3 ->3x32x32
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), #R,G,B每层的归一化用到的均值和方差 即参数为变换过程,而非最终结果。
])
transform_test = transforms.Compose([
trans
Resnet 18 可跑完整pytroch代码
最新推荐文章于 2024-08-16 01:12:06 发布
本文提供了一段完整的Resnet18模型实现代码,并详细解析了每一步操作,帮助读者深入理解PyTorch框架下Residual Network的构建过程。
摘要由CSDN通过智能技术生成