参考李沐的课件与教材,将重点部分做了注释
1. ndarray实现
%matplotlib inline
from IPython import display
from matplotlib import pyplot as plt
from mxnet import autograd, nd
import random
# 本函数已保存在 gluonbook 包中方便以后使用。
def data_iter(batch_size, features, labels):
num_examples = len(features)
indices = list(range(num_examples))
#print('indices=',indices)
random.shuffle(indices) # 样本的读取顺序是随机的。
for i in range(0, num_examples, batch_size):
j = nd.array(indices[i: min(i + batch_size, num_examples)])
yield features.take(j), labels.take(j) # take 函数根据索引返回对应元素。yield生成器用法
def use_svg_display():
# 用矢量图显示。
display.set_matplotlib_formats('svg')
def set_figsize(figsize=(3.5, 2.5)):
use_svg_display()
# 设置图的尺寸。
plt.rcParams['figure.figsize'] = figsize
def linreg(X, w, b): # 本函数已保存在 gluonbook 包中方便以后使用。
return nd.dot(X, w) + b
def sq