1. 四种优化器效果展示代码
import torch
import matplotlib.pyplot as plt
import torch.utils.data as Data
import os
# NOTE(review): KMP_DUPLICATE_LIB_OK=TRUE presumably works around the Intel
# OpenMP "runtime already initialized" abort that occurs when more than one
# library bundles its own OpenMP runtime. It masks the conflict rather than
# resolving it -- confirm it is still needed in this environment.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Training hyper-parameters.
LR = 0.01        # presumably the learning rate handed to each optimizer
BATCH_SIZE = 32  # mini-batch size (passed to the DataLoader as batch_size)
EPOCH = 12       # presumably the number of training epochs

# Synthetic 1-D regression data: y = x^2 + 0.2 * N(0, 1) noise.
# x holds 1000 evenly spaced points on [-1, 1], reshaped to a (1000, 1) column.
x = torch.linspace(-1, 1, 1000).unsqueeze(dim=1)
# torch.normal with a tensor mean (all zeros) and default std=1.0 draws
# standard-normal noise of the same shape as x.
y = x ** 2 + 0.2 * torch.normal(torch.zeros(*x.shape))
# Pair inputs with targets so a DataLoader can batch them.
torch_data = Data.TensorDataset(x, y)
# for mini batch
loader = Data.DataLoader(
dataset=torch_data,
batch_size=BATCH_SIZE