优化器
pytorch的优化器:管理并更新模型中科学系参数的值,使得模型输出更接近真实标签
导数:函数在制定坐标轴上的变化
方向导数:指定方向上的变化率
梯度:一个向量,方向为方向导数取得最大值的方向
基本属性
defaults 优化器超参数
state 参数的缓存 如momentum的缓存
pram_groups 管理的参数组
_step_count 记录更新次数,学习率调整中使用
基本方法
zero_grad() 清空所管理参数的梯度
pytorch特性:张量梯度不自动清零
step() 执行一步更新
add_param_group() 添加参数组
state_dict() 获取优化器当前状态信息字典
load_state_dict() 加载状态信息字典
# -*- coding: utf-8 -*-
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import transform_invert, set_seed
set_seed(1) # 设置随机种子
rmb_label = {
"1": 0, "100": 1}
# 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1
# ============================ step 1/5 数据 ============================
split_dir = os.path.join("data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.RandomCrop(32, padding=4),
transforms.RandomGrayscale(p=0.8),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm