莫烦 Python 视频学习笔记，视频链接：https://www.bilibili.com/video/BV1Vx411j7kT?from=search&seid=3065687802317837578
主题：优化器（Optimizer）学习
代码：
import torch
import torch.utils.data as Data
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os
# Let duplicate OpenMP runtimes coexist (KMP_DUPLICATE_LIB_OK). The original
# note says this line (and the `import os` above) was added only to avoid an
# error at run time.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

# Hyperparameters. The names suggest learning rate, mini-batch size and number
# of epochs; they are presumably consumed later in the script (not visible here).
LR = 0.01
BATCH_SIZE = 32
EPOCH = 12

# Toy regression data: 100 evenly spaced points on [-1, 1] as a (100, 1)
# column, with targets x^2 plus Gaussian noise (torch.normal's default std of 1,
# scaled by 0.1).
x = torch.linspace(-1, 1, 100).unsqueeze(1)
y = x ** 2 + 0.1 * torch.normal(torch.zeros(x.size()))
# Visualise the noisy samples: x on the horizontal axis, y on the vertical.
plt.scatter(x=x.numpy(), y=y.numpy())
plt.