以下代码参考:
https://blog.csdn.net/wsp_1138886114/article/details/83717787
torchvision: PyTorch框架中的包
.
├── torchvision.datasets 自带的数据集
| ├── MNIST、COCO(图像标注/目标检测)、CIFAR10/CIFAR100
| └── LSUN Classification、ImageFolder、Imagenet-12、STL10
|
├── torchvision.models 自带的当前流行模型
| └──AlexNet、VGG、ResNet、SqueezeNet、DenseNet
|
├── torchvision.transforms 数据处理
| └── transforms.Compose(transforms) 一系列的transforms 操作
| └── data augmentation 包含resize、crop等常见数据增强操作
| ├──transformas.py
| └──functional.py
└─── torchvision.utils
├── torchvision.utils.make_grid
└── torchvision.utils.save_image
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import torchvision
from torchvision import datasets, models, transforms
import time
import os
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
"""
在每个epoch开始时都需要如下更新:
scheduler.step() 模型调度
model.train(True) 设置模型状态为训练状态
optimizer.zero_grad() 将网络中的所有梯度置0
outputs = model(inputs) 数据输入:网络的前向传播了
torch.max(outputs.data, 1) 模型预测该样本属于哪个类别,torch.max(tensor格式,每一行的最大值)
loss = criterion(outputs, labels) 输出outputs和原labels作为loss函数的输入就可以得到损失
loss.backward() 回传损失(只训练时用)
optimizer.step() 更新参数
"""
since = time.time()
best_model_wts = model.state_dict()
best_acc = 0.0
for epoch in range(num_epochs):
print